diff --git a/core/vm/contracts_zkevm.go b/core/vm/contracts_zkevm.go index 895e1162820..3e96dd6ad52 100644 --- a/core/vm/contracts_zkevm.go +++ b/core/vm/contracts_zkevm.go @@ -400,18 +400,24 @@ func (c *bigModExp_zkevm) Run(input []byte) ([]byte, error) { mod = big.NewInt(0) ) - if len(input) >= 96 + int(baseLen) { - base = new(big.Int).SetBytes(getData(input, 96, uint64(baseLen))) + // Extract `base`, `exp`, and `mod` with padding as needed + baseData := getData(input, 96, uint64(baseLen)) + if uint64(len(baseData)) < baseLen { + baseData = common.RightPadBytes(baseData, int(baseLen)) } - if len(input) >= 96 + int(baseLen) + int(expLen) { - exp = new(big.Int).SetBytes(getData(input, 96 + uint64(baseLen), uint64(expLen))) - } - if len(input) >= 96 + int(baseLen) + int(expLen) + int(modLen) { - mod = new(big.Int).SetBytes(getData(input, 96 + uint64(baseLen) + uint64(expLen), uint64(modLen))) + base.SetBytes(baseData) + + expData := getData(input, 96+uint64(baseLen), uint64(expLen)) + if uint64(len(expData)) < expLen { + expData = common.RightPadBytes(expData, int(expLen)) } - if len(input) < 96 + int(baseLen) + int(expLen) + int(modLen) { - input = common.LeftPadBytes(input, 96 + int(baseLen) + int(expLen) + int(modLen)) + exp.SetBytes(expData) + + modData := getData(input, 96+uint64(baseLen)+uint64(expLen), uint64(modLen)) + if uint64(len(modData)) < modLen { + modData = common.RightPadBytes(modData, int(modLen)) } + mod.SetBytes(modData) // Retrieve the operands and execute the exponentiation var ( @@ -422,7 +428,7 @@ func (c *bigModExp_zkevm) Run(input []byte) ([]byte, error) { ) if modBitLen == 0 { - return []byte{}, nil + return common.LeftPadBytes([]byte{}, int(modLen)), nil } if baseBitLen == 0 {