Skip to content

Commit

Permalink
[ADD] Support for cairo1 account calldata (#82)
Browse files Browse the repository at this point in the history
* [ADD] Support for cairo1 account calldata #79
  • Loading branch information
dioKaratzas committed Jul 18, 2023
1 parent 9fd942a commit e4b41ed
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 26 deletions.
11 changes: 9 additions & 2 deletions Sources/Starknet/Accounts/StarknetAccount.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@ public enum StarknetAccountError: Error {
case invalidResponse
}

public enum CairoVersion: String, Encodable {
case zero
case one
}

public class StarknetAccount: StarknetAccountProtocol {
private let version = Felt.one
private let cairoVersion: CairoVersion

private var estimateVersion: Felt {
Felt(BigUInt(2).power(128).advanced(by: BigInt(version.value)))!
Expand All @@ -17,10 +23,11 @@ public class StarknetAccount: StarknetAccountProtocol {
private let signer: StarknetSignerProtocol
private let provider: StarknetProviderProtocol

public init(address: Felt, signer: StarknetSignerProtocol, provider: StarknetProviderProtocol) {
public init(address: Felt, signer: StarknetSignerProtocol, provider: StarknetProviderProtocol, cairoVersion: CairoVersion) {
self.address = address
self.signer = signer
self.provider = provider
self.cairoVersion = cairoVersion
}

private func makeSequencerInvokeTransaction(calldata: StarknetCalldata, signature: StarknetSignature, params: StarknetExecutionParams, version: Felt) -> StarknetSequencerInvokeTransaction {
Expand All @@ -41,7 +48,7 @@ public class StarknetAccount: StarknetAccountProtocol {

public func sign(calls: [StarknetCall], params: StarknetExecutionParams, forFeeEstimation: Bool) throws -> StarknetSequencerInvokeTransaction {
let version = forFeeEstimation ? estimateVersion : version
let calldata = starknetCallsToExecuteCalldata(calls: calls)
let calldata = starknetCallsToExecuteCalldata(calls: calls, cairoVersion: cairoVersion)

let sequencerTransaction = makeSequencerInvokeTransaction(calldata: calldata, signature: [], params: params, version: version)

Expand Down
26 changes: 25 additions & 1 deletion Sources/Starknet/Data/Execution.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,16 @@ public struct StarknetOptionalExecutionParams {
}
}

public func starknetCallsToExecuteCalldata(calls: [StarknetCall]) -> [Felt] {
public func starknetCallsToExecuteCalldata(calls: [StarknetCall], cairoVersion: CairoVersion) -> [Felt] {
switch cairoVersion {
case .zero:
return starknetCallsToExecuteCalldataCairo0(calls: calls)
case .one:
return starknetCallsToExecuteCalldataCairo1(calls: calls)
}
}

private func starknetCallsToExecuteCalldataCairo0(calls: [StarknetCall]) -> [Felt] {
var wholeCalldata: [Felt] = []
var callArray: [Felt] = []

Expand All @@ -56,3 +65,18 @@ public func starknetCallsToExecuteCalldata(calls: [StarknetCall]) -> [Felt] {

return [Felt(calls.count)!] + callArray + [Felt(wholeCalldata.count)!] + wholeCalldata
}

private func starknetCallsToExecuteCalldataCairo1(calls: [StarknetCall]) -> [Felt] {
var callArray: [Felt] = []

callArray.append(Felt(calls.count)!)

calls.forEach { call in
callArray.append(call.contractAddress)
callArray.append(call.entrypoint)
callArray.append(Felt(call.calldata.count)!)
callArray.append(contentsOf: call.calldata)
}

return callArray
}
2 changes: 1 addition & 1 deletion Sources/Starknet/Data/Felt.swift
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public extension Felt {
let pairs = hexString.components(withMaxLength: 2)

return pairs.map {
return "\(UnicodeScalar(Int($0, radix: 16)!)!)"
"\(UnicodeScalar(Int($0, radix: 16)!)!)"
}.joined()
}
}
Expand Down
10 changes: 5 additions & 5 deletions Sources/Starknet/Data/StarknetTypedData.swift
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,13 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
}

public func getTypeHash(typeName: String) throws -> Felt {
starknetSelector(from: try encode(type: typeName))
try starknetSelector(from: encode(type: typeName))
}

public func getStructHash(typeName: String, data: [String: Element]) throws -> Felt {
let encodedData = try encode(data: data, forType: typeName)

return StarknetCurve.pedersenOn([try getTypeHash(typeName: typeName)] + encodedData)
return try StarknetCurve.pedersenOn([getTypeHash(typeName: typeName)] + encodedData)
}

public func getStructHash(typeName: String, data: String) throws -> Felt {
Expand All @@ -212,11 +212,11 @@ public struct StarknetTypedData: Codable, Equatable, Hashable {
}

public func getMessageHash(accountAddress: Felt) throws -> Felt {
StarknetCurve.pedersenOn(
try StarknetCurve.pedersenOn(
Felt.fromShortString("StarkNet Message")!,
try getStructHash(typeName: "StarkNetDomain", data: domain),
getStructHash(typeName: "StarkNetDomain", data: domain),
accountAddress,
try getStructHash(typeName: primaryType, data: message)
getStructHash(typeName: primaryType, data: message)
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ enum TransactionReceiptWrapper: Decodable {
// Pending transaction won't have the block_hash value
do {
let _ = try container.decode(Felt.self, forKey: Keys.blockHash)
self = .common(try StarknetCommonTransactionReceipt(from: decoder))
self = try .common(StarknetCommonTransactionReceipt(from: decoder))
} catch {
self = .pending(try StarknetPendingTransactionReceipt(from: decoder))
self = try .pending(StarknetPendingTransactionReceipt(from: decoder))
}
}
}
14 changes: 7 additions & 7 deletions Sources/Starknet/Data/Transaction/TransactionWrapper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,19 @@ enum TransactionWrapper: Decodable {

switch (type, version) {
case (.invoke, .one):
self = .invoke(try StarknetInvokeTransactionV1(from: decoder))
self = try .invoke(StarknetInvokeTransactionV1(from: decoder))
case (.invoke, .zero):
self = .invokeV0(try StarknetInvokeTransactionV0(from: decoder))
self = try .invokeV0(StarknetInvokeTransactionV0(from: decoder))
case (.declare, .one), (.declare, .zero):
self = .declareV1(try StarknetDeclareTransactionV1(from: decoder))
self = try .declareV1(StarknetDeclareTransactionV1(from: decoder))
case (.declare, 2):
self = .declareV2(try StarknetDeclareTransactionV2(from: decoder))
self = try .declareV2(StarknetDeclareTransactionV2(from: decoder))
case (.deploy, .zero):
self = .deploy(try StarknetDeployTransaction(from: decoder))
self = try .deploy(StarknetDeployTransaction(from: decoder))
case (.deployAccount, .one):
self = .deployAccount(try StarknetDeployAccountTransaction(from: decoder))
self = try .deployAccount(StarknetDeployAccountTransaction(from: decoder))
case (.l1Handler, .zero):
self = .l1Handler(try StarknetL1HandlerTransaction(from: decoder))
self = try .l1Handler(StarknetL1HandlerTransaction(from: decoder))
default:
throw DecodingError.dataCorruptedError(forKey: Keys.version, in: container, debugDescription: "Invalid transaction version (\(version) for transaction type (\(type))")
}
Expand Down
6 changes: 3 additions & 3 deletions Tests/StarknetTests/Accounts/AccountTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ final class AccountTests: XCTestCase {

provider = StarknetProvider(starknetChainId: .testnet, url: Self.devnetClient.rpcUrl)!
signer = StarkCurveSigner(privateKey: "0x5421eb02ce8a5a972addcd89daefd93c")!
account = StarknetAccount(address: "0x5fa2c31b541653fc9db108f7d6857a1c2feda8e2abffbfa4ab4eaf1fcbfabd8", signer: signer, provider: provider)
account = StarknetAccount(address: "0x5fa2c31b541653fc9db108f7d6857a1c2feda8e2abffbfa4ab4eaf1fcbfabd8", signer: signer, provider: provider, cairoVersion: .zero)
}

override class func setUp() {
Expand Down Expand Up @@ -107,11 +107,11 @@ final class AccountTests: XCTestCase {
let newSigner = StarkCurveSigner(privateKey: 1234)!
let newPublicKey = newSigner.publicKey
let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountClassHash, calldata: [newPublicKey], salt: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero)

try await Self.devnetClient.prefundAccount(address: newAccountAddress)

let nonce = (try? await newAccount.getNonce()) ?? .zero
let nonce = await (try? newAccount.getNonce()) ?? .zero

let feeEstimate = try await newAccount.estimateDeployAccountFee(classHash: accountClassHash, calldata: [newPublicKey], salt: .zero, nonce: nonce)
let maxFee = estimatedFeeToMaxFee(feeEstimate.overallFee)
Expand Down
84 changes: 84 additions & 0 deletions Tests/StarknetTests/Data/ExecutionTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import XCTest

@testable import Starknet

final class ExecutionTests: XCTestCase {
static var devnetClient: DevnetClientProtocol!

var provider: StarknetProviderProtocol!
var signer: StarknetSignerProtocol!
var account: StarknetAccountProtocol!
var balanceContractAddress: Felt!

override func setUp() async throws {
try await super.setUp()

if !Self.devnetClient.isRunning() {
try await Self.devnetClient.start()
}

provider = StarknetProvider(starknetChainId: .testnet, url: Self.devnetClient.rpcUrl)!
signer = StarkCurveSigner(privateKey: "0x5421eb02ce8a5a972addcd89daefd93c")!
account = StarknetAccount(address: "0x5fa2c31b541653fc9db108f7d6857a1c2feda8e2abffbfa4ab4eaf1fcbfabd8", signer: signer, provider: provider, cairoVersion: .one)
balanceContractAddress = try await Self.devnetClient.deployContract(contractName: "balance", deprecated: true).address
}

override class func setUp() {
super.setUp()
devnetClient = makeDevnetClient()
}

override class func tearDown() {
super.tearDown()

if let devnetClient = Self.devnetClient {
devnetClient.close()
}
}

func testStarknetCallsToExecuteCalldataCairo1() throws {
let call1 = StarknetCall(
contractAddress: balanceContractAddress,
entrypoint: starknetSelector(from: "increase_balance"),
calldata: [Felt(10), Felt(20), Felt(30)]
)

let call2 = StarknetCall(
contractAddress: Felt(999),
entrypoint: starknetSelector(from: "empty_calldata"),
calldata: []
)

let call3 = StarknetCall(
contractAddress: Felt(123),
entrypoint: starknetSelector(from: "another_method"),
calldata: [Felt(100), Felt(200)]
)
let params = StarknetExecutionParams(nonce: .zero, maxFee: .zero)

let signedTx = try account.sign(calls: [call1, call2, call3], params: params)
let expectedCalldata = [
Felt(3),
balanceContractAddress,
starknetSelector(from: "increase_balance"),
Felt(3),
Felt(10),
Felt(20),
Felt(30),
Felt(999),
starknetSelector(from: "empty_calldata"),
Felt(0),
Felt(123),
starknetSelector(from: "another_method"),
Felt(2),
Felt(100),
Felt(200),
]

XCTAssertEqual(expectedCalldata, signedTx.calldata)

let signedEmptyTx = try account.sign(calls: [], params: params)

XCTAssertEqual([.zero], signedEmptyTx.calldata)
}
}
10 changes: 5 additions & 5 deletions Tests/StarknetTests/Providers/ProviderTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ final class ProviderTests: XCTestCase {
let acc = try await ProviderTests.devnetClient.deployAccount(name: "test_events")
let contract = try await ProviderTests.devnetClient.deployContract(contractName: "events", deprecated: true)
let sigerProtocol = StarkCurveSigner(privateKey: acc.details.privateKey)
let account = StarknetAccount(address: acc.details.address, signer: sigerProtocol!, provider: provider)
let account = StarknetAccount(address: acc.details.address, signer: sigerProtocol!, provider: provider, cairoVersion: .zero)
let call = StarknetCall(contractAddress: contract.address, entrypoint: starknetSelector(from: "increase_balance"), calldata: [2137])
let _ = try await account.execute(call: call)

Expand Down Expand Up @@ -123,7 +123,7 @@ final class ProviderTests: XCTestCase {
let acc = try await ProviderTests.devnetClient.deployAccount(name: "test_receipt")
let contract = try await ProviderTests.devnetClient.deployContract(contractName: "events", deprecated: true)
let sigerProtocol = StarkCurveSigner(privateKey: acc.details.privateKey)
let account = StarknetAccount(address: acc.details.address, signer: sigerProtocol!, provider: provider)
let account = StarknetAccount(address: acc.details.address, signer: sigerProtocol!, provider: provider, cairoVersion: .zero)
let call = StarknetCall(contractAddress: contract.address, entrypoint: starknetSelector(from: "increase_balance"), calldata: [2137])
let invoke = try await account.execute(call: call)

Expand All @@ -135,7 +135,7 @@ final class ProviderTests: XCTestCase {
let acc = try await ProviderTests.devnetClient.deployAccount(name: "test_estimate_fee")
let contract = try await ProviderTests.devnetClient.deployContract(contractName: "balance", deprecated: true)
let signer = StarkCurveSigner(privateKey: acc.details.privateKey)!
let account = StarknetAccount(address: acc.details.address, signer: signer, provider: provider)
let account = StarknetAccount(address: acc.details.address, signer: signer, provider: provider, cairoVersion: .zero)

let nonce = try await account.getNonce()

Expand All @@ -157,7 +157,7 @@ final class ProviderTests: XCTestCase {
let acc = try await ProviderTests.devnetClient.deployAccount(name: "test_simulate_transactions")
let signer = StarkCurveSigner(privateKey: acc.details.privateKey)!
let contract = try await ProviderTests.devnetClient.deployContract(contractName: "balance", deprecated: true)
let account = StarknetAccount(address: acc.details.address, signer: signer, provider: provider)
let account = StarknetAccount(address: acc.details.address, signer: signer, provider: provider, cairoVersion: .zero)

let nonce = try await account.getNonce()

Expand All @@ -170,7 +170,7 @@ final class ProviderTests: XCTestCase {
let newSigner = StarkCurveSigner(privateKey: 1234)!
let newPublicKey = newSigner.publicKey
let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountClassHash, calldata: [newPublicKey], salt: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero)

try await Self.devnetClient.prefundAccount(address: newAccountAddress)

Expand Down

0 comments on commit e4b41ed

Please sign in to comment.