diff --git a/Sources/Starknet/Accounts/StarknetAccount.swift b/Sources/Starknet/Accounts/StarknetAccount.swift index fb6af77f5..98152bc7f 100644 --- a/Sources/Starknet/Accounts/StarknetAccount.swift +++ b/Sources/Starknet/Accounts/StarknetAccount.swift @@ -5,8 +5,14 @@ public enum StarknetAccountError: Error { case invalidResponse } +public enum CairoVersion: String, Encodable { + case zero + case one +} + public class StarknetAccount: StarknetAccountProtocol { private let version = Felt.one + private let cairoVersion: CairoVersion private var estimateVersion: Felt { Felt(BigUInt(2).power(128).advanced(by: BigInt(version.value)))! @@ -17,10 +23,11 @@ public class StarknetAccount: StarknetAccountProtocol { private let signer: StarknetSignerProtocol private let provider: StarknetProviderProtocol - public init(address: Felt, signer: StarknetSignerProtocol, provider: StarknetProviderProtocol) { + public init(address: Felt, signer: StarknetSignerProtocol, provider: StarknetProviderProtocol, cairoVersion: CairoVersion) { self.address = address self.signer = signer self.provider = provider + self.cairoVersion = cairoVersion } private func makeSequencerInvokeTransaction(calldata: StarknetCalldata, signature: StarknetSignature, params: StarknetExecutionParams, version: Felt) -> StarknetSequencerInvokeTransaction { @@ -41,7 +48,7 @@ public class StarknetAccount: StarknetAccountProtocol { public func sign(calls: [StarknetCall], params: StarknetExecutionParams, forFeeEstimation: Bool) throws -> StarknetSequencerInvokeTransaction { let version = forFeeEstimation ? estimateVersion : version - let calldata = starknetCallsToExecuteCalldata(calls: calls) + let calldata = starknetCallsToExecuteCalldata(calls: calls, cairoVersion: cairoVersion) let sequencerTransaction = makeSequencerInvokeTransaction(calldata: calldata, signature: [], params: params, version: version) diff --git a/Sources/Starknet/Data/Execution.swift b/Sources/Starknet/Data/Execution.swift index 6a71b5522..38f9a1892 100644 --- a/Sources/Starknet/Data/Execution.swift +++ b/Sources/Starknet/Data/Execution.swift @@ -41,7 +41,16 @@ public struct StarknetOptionalExecutionParams { } } -public func starknetCallsToExecuteCalldata(calls: [StarknetCall]) -> [Felt] { +public func starknetCallsToExecuteCalldata(calls: [StarknetCall], cairoVersion: CairoVersion) -> [Felt] { + switch cairoVersion { + case .zero: + return starknetCallsToExecuteCalldataCairo0(calls: calls) + case .one: + return starknetCallsToExecuteCalldataCairo1(calls: calls) + } +} + +private func starknetCallsToExecuteCalldataCairo0(calls: [StarknetCall]) -> [Felt] { var wholeCalldata: [Felt] = [] var callArray: [Felt] = [] @@ -56,3 +65,18 @@ public func starknetCallsToExecuteCalldata(calls: [StarknetCall]) -> [Felt] { return [Felt(calls.count)!] + callArray + [Felt(wholeCalldata.count)!] + wholeCalldata } + +private func starknetCallsToExecuteCalldataCairo1(calls: [StarknetCall]) -> [Felt] { + var callArray: [Felt] = [] + + callArray.append(Felt(calls.count)!) + + calls.forEach { call in + callArray.append(call.contractAddress) + callArray.append(call.entrypoint) + callArray.append(Felt(call.calldata.count)!) + callArray.append(contentsOf: call.calldata) + } + + return callArray +} diff --git a/Sources/Starknet/Data/Felt.swift b/Sources/Starknet/Data/Felt.swift index 5ff394592..adfcc00ff 100644 --- a/Sources/Starknet/Data/Felt.swift +++ b/Sources/Starknet/Data/Felt.swift @@ -148,7 +148,7 @@ public extension Felt { let pairs = hexString.components(withMaxLength: 2) return pairs.map { - return "\(UnicodeScalar(Int($0, radix: 16)!)!)" + "\(UnicodeScalar(Int($0, radix: 16)!)!)" }.joined() } } diff --git a/Sources/Starknet/Data/StarknetTypedData.swift b/Sources/Starknet/Data/StarknetTypedData.swift index bd9332f83..671368dc0 100644 --- a/Sources/Starknet/Data/StarknetTypedData.swift +++ b/Sources/Starknet/Data/StarknetTypedData.swift @@ -190,13 +190,13 @@ public struct StarknetTypedData: Codable, Equatable, Hashable { } public func getTypeHash(typeName: String) throws -> Felt { - starknetSelector(from: try encode(type: typeName)) + try starknetSelector(from: encode(type: typeName)) } public func getStructHash(typeName: String, data: [String: Element]) throws -> Felt { let encodedData = try encode(data: data, forType: typeName) - return StarknetCurve.pedersenOn([try getTypeHash(typeName: typeName)] + encodedData) + return try StarknetCurve.pedersenOn([getTypeHash(typeName: typeName)] + encodedData) } public func getStructHash(typeName: String, data: String) throws -> Felt { @@ -212,11 +212,11 @@ public struct StarknetTypedData: Codable, Equatable, Hashable { } public func getMessageHash(accountAddress: Felt) throws -> Felt { - StarknetCurve.pedersenOn( + try StarknetCurve.pedersenOn( Felt.fromShortString("StarkNet Message")!, - try getStructHash(typeName: "StarkNetDomain", data: domain), + getStructHash(typeName: "StarkNetDomain", data: domain), accountAddress, - try getStructHash(typeName: primaryType, data: message) + getStructHash(typeName: primaryType, data: message) ) } } diff --git a/Sources/Starknet/Data/Transaction/TransactionReceiptWrapper.swift b/Sources/Starknet/Data/Transaction/TransactionReceiptWrapper.swift index 5510e4b90..bb4796c2b 100644 --- a/Sources/Starknet/Data/Transaction/TransactionReceiptWrapper.swift +++ b/Sources/Starknet/Data/Transaction/TransactionReceiptWrapper.swift @@ -23,9 +23,9 @@ enum TransactionReceiptWrapper: Decodable { // Pending transaction won't have the block_hash value do { let _ = try container.decode(Felt.self, forKey: Keys.blockHash) - self = .common(try StarknetCommonTransactionReceipt(from: decoder)) + self = try .common(StarknetCommonTransactionReceipt(from: decoder)) } catch { - self = .pending(try StarknetPendingTransactionReceipt(from: decoder)) + self = try .pending(StarknetPendingTransactionReceipt(from: decoder)) } } } diff --git a/Sources/Starknet/Data/Transaction/TransactionWrapper.swift b/Sources/Starknet/Data/Transaction/TransactionWrapper.swift index 131d9407c..929b96fe4 100644 --- a/Sources/Starknet/Data/Transaction/TransactionWrapper.swift +++ b/Sources/Starknet/Data/Transaction/TransactionWrapper.swift @@ -41,19 +41,19 @@ enum TransactionWrapper: Decodable { switch (type, version) { case (.invoke, .one): - self = .invoke(try StarknetInvokeTransactionV1(from: decoder)) + self = try .invoke(StarknetInvokeTransactionV1(from: decoder)) case (.invoke, .zero): - self = .invokeV0(try StarknetInvokeTransactionV0(from: decoder)) + self = try .invokeV0(StarknetInvokeTransactionV0(from: decoder)) case (.declare, .one), (.declare, .zero): - self = .declareV1(try StarknetDeclareTransactionV1(from: decoder)) + self = try .declareV1(StarknetDeclareTransactionV1(from: decoder)) case (.declare, 2): - self = .declareV2(try StarknetDeclareTransactionV2(from: decoder)) + self = try .declareV2(StarknetDeclareTransactionV2(from: decoder)) case (.deploy, .zero): - self = .deploy(try StarknetDeployTransaction(from: decoder)) + self = try .deploy(StarknetDeployTransaction(from: decoder)) case (.deployAccount, .one): - self = .deployAccount(try StarknetDeployAccountTransaction(from: decoder)) + self = try .deployAccount(StarknetDeployAccountTransaction(from: decoder)) case (.l1Handler, .zero): - self = .l1Handler(try StarknetL1HandlerTransaction(from: decoder)) + self = try .l1Handler(StarknetL1HandlerTransaction(from: decoder)) default: throw DecodingError.dataCorruptedError(forKey: Keys.version, in: container, debugDescription: "Invalid transaction version (\(version) for transaction type (\(type))") } diff --git a/Tests/StarknetTests/Accounts/AccountTest.swift b/Tests/StarknetTests/Accounts/AccountTest.swift index fa7e02164..bde787a04 100644 --- a/Tests/StarknetTests/Accounts/AccountTest.swift +++ b/Tests/StarknetTests/Accounts/AccountTest.swift @@ -26,7 +26,7 @@ final class AccountTests: XCTestCase { provider = StarknetProvider(starknetChainId: .testnet, url: Self.devnetClient.rpcUrl)! signer = StarkCurveSigner(privateKey: "0x5421eb02ce8a5a972addcd89daefd93c")! - account = StarknetAccount(address: "0x5fa2c31b541653fc9db108f7d6857a1c2feda8e2abffbfa4ab4eaf1fcbfabd8", signer: signer, provider: provider) + account = StarknetAccount(address: "0x5fa2c31b541653fc9db108f7d6857a1c2feda8e2abffbfa4ab4eaf1fcbfabd8", signer: signer, provider: provider, cairoVersion: .zero) } override class func setUp() { @@ -107,11 +107,11 @@ final class AccountTests: XCTestCase { let newSigner = StarkCurveSigner(privateKey: 1234)! let newPublicKey = newSigner.publicKey let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountClassHash, calldata: [newPublicKey], salt: .zero) - let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider) + let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero) try await Self.devnetClient.prefundAccount(address: newAccountAddress) - let nonce = (try? await newAccount.getNonce()) ?? .zero + let nonce = await (try? newAccount.getNonce()) ?? .zero let feeEstimate = try await newAccount.estimateDeployAccountFee(classHash: accountClassHash, calldata: [newPublicKey], salt: .zero, nonce: nonce) let maxFee = estimatedFeeToMaxFee(feeEstimate.overallFee) diff --git a/Tests/StarknetTests/Data/ExecutionTests.swift b/Tests/StarknetTests/Data/ExecutionTests.swift new file mode 100644 index 000000000..75ab8536c --- /dev/null +++ b/Tests/StarknetTests/Data/ExecutionTests.swift @@ -0,0 +1,84 @@ +import XCTest + +@testable import Starknet + +final class ExecutionTests: XCTestCase { + static var devnetClient: DevnetClientProtocol! + + var provider: StarknetProviderProtocol! + var signer: StarknetSignerProtocol! + var account: StarknetAccountProtocol! + var balanceContractAddress: Felt! + + override func setUp() async throws { + try await super.setUp() + + if !Self.devnetClient.isRunning() { + try await Self.devnetClient.start() + } + + provider = StarknetProvider(starknetChainId: .testnet, url: Self.devnetClient.rpcUrl)! + signer = StarkCurveSigner(privateKey: "0x5421eb02ce8a5a972addcd89daefd93c")! + account = StarknetAccount(address: "0x5fa2c31b541653fc9db108f7d6857a1c2feda8e2abffbfa4ab4eaf1fcbfabd8", signer: signer, provider: provider, cairoVersion: .one) + balanceContractAddress = try await Self.devnetClient.deployContract(contractName: "balance", deprecated: true).address + } + + override class func setUp() { + super.setUp() + devnetClient = makeDevnetClient() + } + + override class func tearDown() { + super.tearDown() + + if let devnetClient = Self.devnetClient { + devnetClient.close() + } + } + + func testStarknetCallsToExecuteCalldataCairo1() throws { + let call1 = StarknetCall( + contractAddress: balanceContractAddress, + entrypoint: starknetSelector(from: "increase_balance"), + calldata: [Felt(10), Felt(20), Felt(30)] + ) + + let call2 = StarknetCall( + contractAddress: Felt(999), + entrypoint: starknetSelector(from: "empty_calldata"), + calldata: [] + ) + + let call3 = StarknetCall( + contractAddress: Felt(123), + entrypoint: starknetSelector(from: "another_method"), + calldata: [Felt(100), Felt(200)] + ) + let params = StarknetExecutionParams(nonce: .zero, maxFee: .zero) + + let signedTx = try account.sign(calls: [call1, call2, call3], params: params) + let expectedCalldata = [ + Felt(3), + balanceContractAddress, + starknetSelector(from: "increase_balance"), + Felt(3), + Felt(10), + Felt(20), + Felt(30), + Felt(999), + starknetSelector(from: "empty_calldata"), + Felt(0), + Felt(123), + starknetSelector(from: "another_method"), + Felt(2), + Felt(100), + Felt(200), + ] + + XCTAssertEqual(expectedCalldata, signedTx.calldata) + + let signedEmptyTx = try account.sign(calls: [], params: params) + + XCTAssertEqual([.zero], signedEmptyTx.calldata) + } +} diff --git a/Tests/StarknetTests/Providers/ProviderTests.swift b/Tests/StarknetTests/Providers/ProviderTests.swift index ba0635763..b2c59c8f3 100644 --- a/Tests/StarknetTests/Providers/ProviderTests.swift +++ b/Tests/StarknetTests/Providers/ProviderTests.swift @@ -92,7 +92,7 @@ final class ProviderTests: XCTestCase { let acc = try await ProviderTests.devnetClient.deployAccount(name: "test_events") let contract = try await ProviderTests.devnetClient.deployContract(contractName: "events", deprecated: true) let sigerProtocol = StarkCurveSigner(privateKey: acc.details.privateKey) - let account = StarknetAccount(address: acc.details.address, signer: sigerProtocol!, provider: provider) + let account = StarknetAccount(address: acc.details.address, signer: sigerProtocol!, provider: provider, cairoVersion: .zero) let call = StarknetCall(contractAddress: contract.address, entrypoint: starknetSelector(from: "increase_balance"), calldata: [2137]) let _ = try await account.execute(call: call) @@ -123,7 +123,7 @@ final class ProviderTests: XCTestCase { let acc = try await ProviderTests.devnetClient.deployAccount(name: "test_receipt") let contract = try await ProviderTests.devnetClient.deployContract(contractName: "events", deprecated: true) let sigerProtocol = StarkCurveSigner(privateKey: acc.details.privateKey) - let account = StarknetAccount(address: acc.details.address, signer: sigerProtocol!, provider: provider) + let account = StarknetAccount(address: acc.details.address, signer: sigerProtocol!, provider: provider, cairoVersion: .zero) let call = StarknetCall(contractAddress: contract.address, entrypoint: starknetSelector(from: "increase_balance"), calldata: [2137]) let invoke = try await account.execute(call: call) @@ -135,7 +135,7 @@ final class ProviderTests: XCTestCase { let acc = try await ProviderTests.devnetClient.deployAccount(name: "test_estimate_fee") let contract = try await ProviderTests.devnetClient.deployContract(contractName: "balance", deprecated: true) let signer = StarkCurveSigner(privateKey: acc.details.privateKey)! - let account = StarknetAccount(address: acc.details.address, signer: signer, provider: provider) + let account = StarknetAccount(address: acc.details.address, signer: signer, provider: provider, cairoVersion: .zero) let nonce = try await account.getNonce() @@ -157,7 +157,7 @@ final class ProviderTests: XCTestCase { let acc = try await ProviderTests.devnetClient.deployAccount(name: "test_simulate_transactions") let signer = StarkCurveSigner(privateKey: acc.details.privateKey)! let contract = try await ProviderTests.devnetClient.deployContract(contractName: "balance", deprecated: true) - let account = StarknetAccount(address: acc.details.address, signer: signer, provider: provider) + let account = StarknetAccount(address: acc.details.address, signer: signer, provider: provider, cairoVersion: .zero) let nonce = try await account.getNonce() @@ -170,7 +170,7 @@ final class ProviderTests: XCTestCase { let newSigner = StarkCurveSigner(privateKey: 1234)! let newPublicKey = newSigner.publicKey let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountClassHash, calldata: [newPublicKey], salt: .zero) - let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider) + let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero) try await Self.devnetClient.prefundAccount(address: newAccountAddress)