Skip to content

Commit

Permalink
Added simulateTransaction endpoint to the rpc provider (#78)
Browse files Browse the repository at this point in the history
Added simulateTransaction endpoint
  • Loading branch information
bartekryba committed Jul 5, 2023
1 parent 5f0b855 commit a7218ea
Show file tree
Hide file tree
Showing 9 changed files with 300 additions and 12 deletions.
10 changes: 10 additions & 0 deletions Sources/Starknet/Data/Events.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,13 @@ public struct StarknetEvent: Decodable, Equatable {
case data
}
}

public struct StarknetEventContent: Decodable, Equatable {
public let keys: [Felt]
public let data: [Felt]

enum CodingKeys: String, CodingKey {
case keys
case data
}
}
167 changes: 167 additions & 0 deletions Sources/Starknet/Data/Transaction/TransactionTrace.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import Foundation

public enum StarknetEntryPointType: String, Decodable {
case external = "EXTERNAL"
case l1Handler = "L1_HANDLER"
case constructor = "CONSTRUCTOR"
}

public enum StarknetCallType: String, Decodable {
case call = "CALL"
case libraryCall = "LIBRARY_CALL"
}

public enum StarknetSimulationFlag: String, Codable {
case skipValidate = "SKIP_VALIDATE"
case skipExecute = "SKIP_EXECUTE"
}

public struct StarknetFunctionInvocation: Decodable, Equatable {
public let contractAddress: Felt
public let entrypoint: Felt
public let calldata: StarknetCalldata
public let callerAddress: Felt?
public let classHash: Felt?
public let entryPointType: StarknetEntryPointType?
public let callType: StarknetCallType?
public let result: [Felt]?
public let calls: [StarknetFunctionInvocation]?
public let events: [StarknetEventContent]?
public let messages: [MessageToL1]?

private enum CodingKeys: String, CodingKey {
case contractAddress = "contract_address"
case entrypoint = "entry_point_selector"
case calldata
case callerAddress = "caller_address"
case classHash = "class_hash"
case entryPointType = "entry_point_type"
case callType = "call_type"
case result
case calls
case events
case messages
}
}

public protocol StarknetTransactionTrace: Decodable, Equatable {}

public struct StarknetInvokeTransactionTrace: StarknetTransactionTrace {
public let validateInvocation: StarknetFunctionInvocation?
public let executeInvocation: StarknetFunctionInvocation?
public let feeTransferInvocation: StarknetFunctionInvocation?

private enum CodingKeys: String, CodingKey {
case validateInvocation = "validate_invocation"
case executeInvocation = "execute_invocation"
case feeTransferInvocation = "fee_transfer_invocation"
}
}

public struct StarknetDeployAccountTransactionTrace: StarknetTransactionTrace {
public let validateInvocation: StarknetFunctionInvocation?
public let constructorInvocation: StarknetFunctionInvocation?
public let feeTransferInvocation: StarknetFunctionInvocation?

private enum CodingKeys: String, CodingKey {
case validateInvocation = "validate_invocation"
case constructorInvocation = "constructor_invocation"
case feeTransferInvocation = "fee_transfer_invocation"
}
}

public struct StarknetL1HandlerTransactionTrace: StarknetTransactionTrace {
public let functionInvocation: StarknetFunctionInvocation?

private enum CodingKeys: String, CodingKey {
case functionInvocation = "function_invocation"
}
}

enum StarknetTransactionTraceWrapper: Decodable {
fileprivate enum Keys: String, CodingKey {
case validateInvocation = "validate_invocation"
case executeInvocation = "execute_invocation"
case feeTransferInvocation = "fee_transfer_invocation"
case constructorInvocation = "constructor_invocation"
case functionInvocation = "function_invocation"
}

case invoke(StarknetInvokeTransactionTrace)
case deployAccount(StarknetDeployAccountTransactionTrace)
case l1Handler(StarknetL1HandlerTransactionTrace)

public var transactionTrace: any StarknetTransactionTrace {
switch self {
case let .invoke(txTrace):
return txTrace
case let .deployAccount(txTrace):
return txTrace
case let .l1Handler(txTrace):
return txTrace
}
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: Keys.self)

// Invocations can be null, so `if let = try?` syntax won't work here.
do {
let validateInvocation = try container.decode(StarknetFunctionInvocation?.self, forKey: .validateInvocation)
let executeInvocation = try container.decode(StarknetFunctionInvocation?.self, forKey: .executeInvocation)
let feeTransferInvocation = try container.decode(StarknetFunctionInvocation?.self, forKey: .feeTransferInvocation)

self = .invoke(StarknetInvokeTransactionTrace(
validateInvocation: validateInvocation,
executeInvocation: executeInvocation,
feeTransferInvocation: feeTransferInvocation
))
return
} catch {}

do {
let validateInvocation = try container.decode(StarknetFunctionInvocation?.self, forKey: .validateInvocation)
let constructorInvocation = try container.decode(StarknetFunctionInvocation?.self, forKey: .constructorInvocation)
let feeTransferInvocation = try container.decode(StarknetFunctionInvocation?.self, forKey: .feeTransferInvocation)

self = .deployAccount(StarknetDeployAccountTransactionTrace(
validateInvocation: validateInvocation,
constructorInvocation: constructorInvocation,
feeTransferInvocation: feeTransferInvocation
))
return
} catch {}

do {
let functionInvocation = try container.decode(StarknetFunctionInvocation?.self, forKey: .functionInvocation)

self = .l1Handler(StarknetL1HandlerTransactionTrace(
functionInvocation: functionInvocation
))
return
} catch {}

throw DecodingError.dataCorrupted(
DecodingError.Context(
codingPath: container.codingPath,
debugDescription: "Invalid transaction trace"
))
}
}

public struct StarknetSimulatedTransaction: Decodable {
public let transactionTrace: any StarknetTransactionTrace
public let feeEstimation: StarknetFeeEstimate

enum CodingKeys: String, CodingKey {
case transactionTrace = "transaction_trace"
case feeEstimation = "fee_estimation"
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)

transactionTrace = try container.decode(StarknetTransactionTraceWrapper.self, forKey: .transactionTrace).transactionTrace
feeEstimation = try container.decode(StarknetFeeEstimate.self, forKey: .feeEstimation)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ enum JsonRpcMethod: String, Encodable {
case getTransactionByHash = "starknet_getTransactionByHash"
case getTransactionByBlockIdAndIndex = "starknet_getTransactionByBlockIdAndIndex"
case getTransactionReceipt = "starknet_getTransactionReceipt"
case simulateTransaction = "starknet_simulateTransaction"
}
40 changes: 31 additions & 9 deletions Sources/Starknet/Providers/StarknetProvider/JsonRpcParams.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ struct AddInvokeTransactionParams: Encodable {
}
}

// Workaround to allow encoding polymorphic array
struct WrappedSequencerTransaction: Encodable {
let transaction: any StarknetSequencerTransaction

func encode(to encoder: Encoder) throws {
try transaction.encode(to: encoder)
}
}

struct EstimateFeeParams: Encodable {
let request: [any StarknetSequencerTransaction]
let blockId: StarknetBlockId

// Walkaround to allow encoding polymorphic array
struct WrappedSequencerTransaction: Encodable {
let transaction: any StarknetSequencerTransaction

func encode(to encoder: Encoder) throws {
try transaction.encode(to: encoder)
}
}

func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)

Expand Down Expand Up @@ -104,4 +104,26 @@ struct GetTransactionReceiptPayload: Encodable {
}
}

struct SimulateTransactionsParams: Encodable {
let transactions: [any StarknetSequencerTransaction]
let blockId: StarknetBlockId
let simulationFlags: Set<StarknetSimulationFlag>

func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)

let wrappedTransactions = transactions.map { WrappedSequencerTransaction(transaction: $0) }

try container.encode(wrappedTransactions, forKey: .transactions)
try container.encode(blockId, forKey: .blockId)
try container.encode(simulationFlags, forKey: .simulationFlags)
}

enum CodingKeys: String, CodingKey {
case transactions
case blockId = "block_id"
case simulationFlags = "simulation_flags"
}
}

struct EmptyParams: Encodable {}
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,12 @@ public class StarknetProvider: StarknetProviderProtocol {

return result.transactionReceipt
}

public func simulateTransactions(_ transactions: [any StarknetSequencerTransaction], at blockId: StarknetBlockId, simulationFlags: Set<StarknetSimulationFlag>) async throws -> [StarknetSimulatedTransaction] {
let params = SimulateTransactionsParams(transactions: transactions, blockId: blockId, simulationFlags: simulationFlags)

let result = try await makeRequest(method: .simulateTransaction, params: params, receive: [StarknetSimulatedTransaction].self)

return result
}
}
19 changes: 18 additions & 1 deletion Sources/Starknet/Providers/StarknetProviderProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,17 @@ public protocol StarknetProviderProtocol {
/// - txHash : the hash of the requested transaction
/// - Returns: receipt of a transaction identified by given hash
func getTransactionReceiptBy(hash: Felt) async throws -> StarknetTransactionReceipt

/// Simulate running a given list of transactions, and generate the execution trace
///
/// - Parameters:
/// - transactions: list of transactions to simulate
/// - blockId: block used to run the simulation
/// - simulationFlags: a set of simulation flags
func simulateTransactions(_ transactions: [any StarknetSequencerTransaction], at blockId: StarknetBlockId, simulationFlags: Set<StarknetSimulationFlag>) async throws -> [StarknetSimulatedTransaction]
}

private let defaultBlockId = StarknetBlockId.tag(.pending)
let defaultBlockId = StarknetBlockId.tag(.pending)

public extension StarknetProviderProtocol {
/// Call starknet contract in the pending block.
Expand Down Expand Up @@ -157,4 +165,13 @@ public extension StarknetProviderProtocol {
func getClassHashAt(_ address: Felt) async throws -> Felt {
try await getClassHashAt(address, at: defaultBlockId)
}

/// Simulate running a given list of transactions in the latest block, and generate the execution trace
///
/// - Parameters:
/// - transactions: list of transactions to simulate
/// - simulationFlags: a set of simulation flags
func simulateTransactions(_ transactions: [any StarknetSequencerTransaction], simulationFlags: Set<StarknetSimulationFlag>) async throws -> [StarknetSimulatedTransaction] {
try await simulateTransactions(transactions, at: defaultBlockId, simulationFlags: simulationFlags)
}
}
57 changes: 57 additions & 0 deletions Tests/StarknetTests/Providers/ProviderTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ final class ProviderTests: XCTestCase {
To run, make sure you're running starknet-devnet on port 5050, with seed 0
*/
static var devnetClient: DevnetClientProtocol!

var provider: StarknetProviderProtocol!

override class func setUp() {
Expand Down Expand Up @@ -151,4 +152,60 @@ final class ProviderTests: XCTestCase {

XCTAssertEqual(fees.count, 2)
}

func testSimulateTransactions() async throws {
let acc = try await ProviderTests.devnetClient.deployAccount(name: "test_simulate_transactions")
let signer = StarkCurveSigner(privateKey: acc.details.privateKey)!
let contract = try await ProviderTests.devnetClient.deployContract(contractName: "balance", deprecated: true)
let account = StarknetAccount(address: acc.details.address, signer: signer, provider: provider)

let nonce = try await account.getNonce()

let call = StarknetCall(contractAddress: contract.address, entrypoint: starknetSelector(from: "increase_balance"), calldata: [1000])
let params = StarknetExecutionParams(nonce: nonce, maxFee: 1_000_000_000_000)

let invokeTx = try account.sign(calls: [call], params: params, forFeeEstimation: true)

let accountClassHash = try await provider.getClassHashAt(account.address)
let newSigner = StarkCurveSigner(privateKey: 1234)!
let newPublicKey = newSigner.publicKey
let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountClassHash, calldata: [newPublicKey], salt: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider)

try await Self.devnetClient.prefundAccount(address: newAccountAddress)

let newAccountParams = StarknetExecutionParams(nonce: 0, maxFee: 0)
let deployAccountTx = try newAccount.signDeployAccount(classHash: accountClassHash, calldata: [newPublicKey], salt: .zero, params: newAccountParams, forFeeEstimation: true)

let simulations = try await provider.simulateTransactions([invokeTx, deployAccountTx], at: .tag(.latest), simulationFlags: [])

XCTAssertEqual(simulations.count, 2)
XCTAssertTrue(simulations[0].transactionTrace is StarknetInvokeTransactionTrace)
XCTAssertTrue(simulations[1].transactionTrace is StarknetDeployAccountTransactionTrace)

let invokeWithoutSignature = StarknetSequencerInvokeTransaction(
senderAddress: invokeTx.senderAddress,
calldata: invokeTx.calldata,
signature: [],
maxFee: invokeTx.maxFee,
nonce: invokeTx.nonce,
version: invokeTx.version
)

let deployAccountWithoutSignature = StarknetSequencerDeployAccountTransaction(
signature: [],
maxFee: deployAccountTx.maxFee,
nonce: deployAccountTx.nonce,
contractAddressSalt: deployAccountTx.contractAddressSalt,
constructorCalldata: deployAccountTx.constructorCalldata,
classHash: deployAccountTx.classHash,
version: deployAccountTx.version
)

let simulations2 = try await provider.simulateTransactions([invokeWithoutSignature, deployAccountWithoutSignature], at: .tag(.latest), simulationFlags: [.skipValidate])

XCTAssertEqual(simulations2.count, 2)
XCTAssertTrue(simulations2[0].transactionTrace is StarknetInvokeTransactionTrace)
XCTAssertTrue(simulations2[1].transactionTrace is StarknetDeployAccountTransactionTrace)
}
}
6 changes: 6 additions & 0 deletions Tests/StarknetTests/Utils/DevnetClient/DevnetClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ func makeDevnetClient() -> DevnetClientProtocol {
private let devnetPath: String
private let starknetPath: String

private var deployedContracts: [String: TransactionResult] = [:]

let gatewayUrl: String
let feederGatewayUrl: String
let rpcUrl: String
Expand Down Expand Up @@ -220,6 +222,10 @@ func makeDevnetClient() -> DevnetClientProtocol {
public func deployContract(contractName: String, deprecated: Bool) async throws -> TransactionResult {
try guardDevnetIsRunning()

if let transactionResult = deployedContracts["contractName"] {
return transactionResult
}

let classHash = try await declareContract(contractName: contractName, deprecated: deprecated)

let params = [
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
starknet-devnet==0.5.2
cairo-lang==0.11.1.1
starknet-devnet==0.5.5
cairo-lang==0.12.0

0 comments on commit a7218ea

Please sign in to comment.