Skip to content

Commit

Permalink
let contracts use multicall scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
kroggen committed Jul 22, 2024
1 parent 4ebaf9d commit dde3772
Show file tree
Hide file tree
Showing 5 changed files with 320 additions and 16 deletions.
41 changes: 34 additions & 7 deletions contract/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ func Call(

// get contract
if ctx.isMultiCall {
bytecode = getMultiCallContract(contractState)
bytecode = getMultiCallContractCode(contractState)
} else {
bytecode = getContractCode(contractState, ctx.bs)
}
Expand Down Expand Up @@ -1262,14 +1262,23 @@ func getCode(contractState *statedb.ContractState, bs *state.BlockState) ([]byte
var code []byte
var err error

if contractState.IsMultiCall() {
return getMultiCallCode(contractState), nil
}

// try to get the code from the blockstate cache
code = bs.GetCode(contractState.GetAccountID())
if code != nil {
return code, nil
}

// get the code from the contract state
code, err = contractState.GetCode()
if err != nil {
return nil, err
}

// add the code to the blockstate cache
bs.AddCode(contractState.GetAccountID(), code)

return code, nil
Expand All @@ -1287,7 +1296,15 @@ func getContractCode(contractState *statedb.ContractState, bs *state.BlockState)
return luacUtil.LuaCode(code).ByteCode()
}

func getMultiCallContract(contractState *statedb.ContractState) []byte {
func getMultiCallContractCode(contractState *statedb.ContractState) []byte {
code := getMultiCallCode(contractState)
if code == nil {
return nil
}
return luacUtil.LuaCode(code).ByteCode()
}

func getMultiCallCode(contractState *statedb.ContractState) []byte {
if multicall_compiled == nil {
// compile the Lua code used to execute multicall txns
var err error
Expand All @@ -1299,16 +1316,21 @@ func getMultiCallContract(contractState *statedb.ContractState) []byte {
}
// set and return the compiled code
contractState.SetMultiCallCode(multicall_compiled)
return multicall_compiled.ByteCode()
return multicall_compiled
}

func GetABI(contractState *statedb.ContractState, bs *state.BlockState) (*types.ABI, error) {
var abi *types.ABI

abi = bs.GetABI(contractState.GetAccountID())
if abi != nil {
return abi, nil
if !contractState.IsMultiCall() { // or IsBuiltinContract()
// try to get the ABI from the blockstate cache
abi = bs.GetABI(contractState.GetAccountID())
if abi != nil {
return abi, nil
}
}

// get the ABI from the contract state
code, err := getCode(contractState, bs)
if err != nil {
return nil, err
Expand All @@ -1326,7 +1348,12 @@ func GetABI(contractState *statedb.ContractState, bs *state.BlockState) (*types.
if err = jsonIter.Unmarshal(rawAbi, abi); err != nil {
return nil, err
}
bs.AddABI(contractState.GetAccountID(), abi)

if !contractState.IsMultiCall() { // or IsBuiltinContract()
// add the ABI to the blockstate cache
bs.AddABI(contractState.GetAccountID(), abi)
}

return abi, nil
}

Expand Down
41 changes: 33 additions & 8 deletions contract/vm_callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,29 +344,54 @@ func luaDelegateCallContract(L *LState, service C.int, contractId *C.char,
return -1, C.CString("[Contract.LuaDelegateCallContract] contract state not found")
}

var isMultiCall bool
var cid []byte
var err error

// get the contract address
cid, err := getAddressNameResolved(contractIdStr, ctx.bs)
if err != nil {
return -1, C.CString("[Contract.LuaDelegateCallContract] invalid contractId: " + err.Error())
if contractIdStr == "multicall" {
isMultiCall = true
argsStr = fnameStr
fnameStr = "execute"
cid = ctx.curContract.contractId
} else {
cid, err = getAddressNameResolved(contractIdStr, ctx.bs)
if err != nil {
return -1, C.CString("[Contract.LuaDelegateCallContract] invalid contractId: " + err.Error())
}
}
aid := types.ToAccountID(cid)

// get the contract state
contractState, err := getOnlyContractState(ctx, cid)
var contractState *statedb.ContractState
if isMultiCall {
contractState = statedb.GetMultiCallState(cid, ctx.curContract.callState.ctrState.State)
} else {
contractState, err = getOnlyContractState(ctx, cid)
}
if err != nil {
return -1, C.CString("[Contract.LuaDelegateCallContract]getContractState error" + err.Error())
}

// check if the contract exists
bytecode := getContractCode(contractState, ctx.bs)
// get the contract code
var bytecode []byte
if isMultiCall {
bytecode = getMultiCallContractCode(contractState)
} else {
bytecode = getContractCode(contractState, ctx.bs)
}
if bytecode == nil {
return -1, C.CString("[Contract.LuaDelegateCallContract] cannot find contract " + contractIdStr)
}

// read the arguments for the contract call
var ci types.CallInfo
ci.Name = fnameStr
err = getCallInfo(&ci.Args, []byte(argsStr), cid)
if isMultiCall {
err = getMultiCallInfo(&ci, []byte(argsStr))
} else {
ci.Name = fnameStr
err = getCallInfo(&ci.Args, []byte(argsStr), cid)
}
if err != nil {
return -1, C.CString("[Contract.LuaDelegateCallContract] invalid arguments: " + err.Error())
}
Expand Down
11 changes: 10 additions & 1 deletion contract/vm_dummy/test_files/feature_multicall.lua
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ function recv_aergo()
-- does nothing
end

function default()
-- does nothing
end

function resend_to(address)
local amount = system.getAmount()
contract.send(address, amount)
Expand All @@ -109,5 +113,10 @@ function send_and_fail(address, amount)
assert(false, "this call should fail")
end

abi.payable(recv_aergo, resend_to, resend_and_fail)
function get_aergo_balance()
return contract.balance()
end

abi.payable(recv_aergo, default, resend_to, resend_and_fail)
abi.register(send_to, send_and_fail)
abi.register_view(get_aergo_balance)
55 changes: 55 additions & 0 deletions contract/vm_dummy/test_files/feature_multicall_contract.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
state.var {
dict = state.map()
}

function set_value(key, value)
dict[key] = value
end

function get_value(key)
return dict[key]
end

abi.register(set_value)
abi.register_view(get_value)

function call(...)
return contract.call(...)
end

function delegate_call(...)
return contract.delegatecall(...)
end

function multicall(script)
return contract.delegatecall("multicall", script)
end

function multicall_and_check(script)
local result1, result2 = contract.delegatecall("multicall", script)
assert(contract.balance() == "875000000000000000")
assert(contract.balance("AmhXhR3Eguhu5qjVoqcg7aCFMpw1GGZJfqDDqfy6RsTP7MrpWeJ9") == "125000000000000000")
return result1, result2
end

abi.register(call, delegate_call, multicall, multicall_and_check)

function recv_aergo()
-- does nothing
end

function default()
-- does nothing
end

function send_to(address, amount)
contract.send(address, amount)
end

function get_balance()
return contract.balance()
end

abi.payable(recv_aergo, default)
abi.register(send_to)
abi.register_view(get_balance)
Loading

0 comments on commit dde3772

Please sign in to comment.