Skip to content

Commit

Permalink
Merge pull request #196 from defi-wonderland/dev
Browse files Browse the repository at this point in the history
feat: make smock work with the EDR-powered version of Hardhat (#195)
  • Loading branch information
0xGorilla authored Mar 5, 2024
2 parents 9cbac21 + 396fa41 commit 7e7d6f3
Show file tree
Hide file tree
Showing 16 changed files with 430 additions and 739 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/canary.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
- uses: actions/checkout@v2
- uses: actions/setup-node@v2
with:
node-version: "14.x"
node-version: "18.x"
registry-url: "https://registry.npmjs.org"
- uses: docker://pandoc/core:2.6
with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- name: Install node
uses: actions/setup-node@v1
with:
node-version: "14.x"
node-version: "18.x"

- name: Install dependencies
run: yarn --frozen-lockfile
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
- uses: actions/checkout@v2
- uses: actions/setup-node@v2
with:
node-version: "14.x"
node-version: "18.x"
registry-url: "https://registry.npmjs.org"
- run: yarn
- run: yarn build
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- name: Install node
uses: actions/setup-node@v1
with:
node-version: "14.x"
node-version: "18.x"

- name: Install dependencies
run: yarn --frozen-lockfile
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ You can install Smock via npm or yarn:
npm install @defi-wonderland/smock
```

> **Note**: Starting from v2.4.0, Smock is only compatible with
> Hardhat v2.21.0 or later. If you are using an older version of Hardhat,
> please install Smock v2.3.5.
## Basic Usage

Smock is dead simple to use. Here's a basic example of how you might use
Expand Down
21 changes: 6 additions & 15 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@defi-wonderland/smock",
"version": "2.3.5",
"version": "2.4.0",
"description": "The Solidity mocking library",
"keywords": [
"ethereum",
Expand Down Expand Up @@ -52,17 +52,8 @@
"*.sol": "cross-env solhint --fix 'contracts/**/*.sol' 'interfaces/**/*.sol'",
"package.json": "sort-package-json"
},
"resolutions": {
"@ethereumjs/block": "3.2.1",
"@ethereumjs/blockchain": "5.2.1",
"@ethereumjs/common": "2.2.0",
"@ethereumjs/tx": "3.1.4",
"@ethereumjs/vm": "5.3.1"
},
"dependencies": {
"@nomicfoundation/ethereumjs-evm": "^1.0.0-rc.3",
"@nomicfoundation/ethereumjs-util": "^8.0.0-rc.3",
"@nomicfoundation/ethereumjs-vm": "^6.0.0-rc.3",
"@nomicfoundation/ethereumjs-util": "^9.0.4",
"diff": "^5.0.0",
"lodash.isequal": "^4.5.0",
"lodash.isequalwith": "^4.4.0",
Expand All @@ -84,15 +75,15 @@
"@types/lodash.isequal": "^4.5.5",
"@types/lodash.isequalwith": "^4.4.6",
"@types/mocha": "8.2.2",
"@types/node": "15.12.2",
"@types/node": "18.19.21",
"@types/readable-stream": "^2.3.13",
"@types/semver": "^7.3.8",
"chai": "4.3.4",
"chai-as-promised": "7.1.1",
"cross-env": "7.0.3",
"ethereum-waffle": "3.4.0",
"ethers": "5.4.1",
"hardhat": "2.11.0",
"hardhat": "^2.21.0",
"hardhat-preprocessor": "0.1.4",
"husky": "6.0.0",
"inquirer": "8.1.0",
Expand All @@ -109,14 +100,14 @@
"ts-node": "10.0.0",
"tsconfig-paths": "^3.9.0",
"typechain": "5.1.1",
"typescript": "4.5.2"
"typescript": "4.9.5"
},
"peerDependencies": {
"@ethersproject/abi": "^5",
"@ethersproject/abstract-provider": "^5",
"@ethersproject/abstract-signer": "^5",
"@nomiclabs/hardhat-ethers": "^2",
"ethers": "^5",
"hardhat": "^2"
"hardhat": "^2.21.0"
}
}
60 changes: 38 additions & 22 deletions src/factories/smock-contract.ts
Original file line number Diff line number Diff line change
@@ -1,33 +1,36 @@
import { Message } from '@nomicfoundation/ethereumjs-evm/dist/message';
import { Address } from '@nomicfoundation/ethereumjs-util';
import { FactoryOptions } from '@nomiclabs/hardhat-ethers/types';
import assert from 'assert';
import { BaseContract, BigNumber, ContractFactory, ethers } from 'ethers';
import { Interface } from 'ethers/lib/utils';
import { ethers as hardhatEthers } from 'hardhat';
import { Observable } from 'rxjs';
import { distinct, filter, map, share, withLatestFrom } from 'rxjs/operators';
import { filter, map, share } from 'rxjs/operators';
import { EditableStorageLogic as EditableStorage } from '../logic/editable-storage-logic';
import { ProgrammableFunctionLogic, SafeProgrammableContract } from '../logic/programmable-function-logic';
import { ReadableStorageLogic as ReadableStorage } from '../logic/readable-storage-logic';
import { ObservableVM } from '../observable-vm';
import { Sandbox } from '../sandbox';
import { ContractCall, FakeContract, MockContractFactory, ProgrammableContractFunction, ProgrammedReturnValue } from '../types';
import { ContractCall, FakeContract, Message, MockContractFactory, ProgrammableContractFunction, ProgrammedReturnValue } from '../types';
import { convertPojoToStruct, fromFancyAddress, impersonate, isPojo, toFancyAddress, toHexString } from '../utils';
import { getStorageLayout } from '../utils/storage';

export async function createFakeContract<Contract extends BaseContract>(
vm: ObservableVM,
address: string,
contractInterface: ethers.utils.Interface,
provider: ethers.providers.Provider
provider: ethers.providers.Provider,
addFunctionToMap: (address: string, sighash: string | null, functionLogic: ProgrammableFunctionLogic) => void
): Promise<FakeContract<Contract>> {
const fake = (await initContract(vm, address, contractInterface, provider)) as unknown as FakeContract<Contract>;
const contractFunctions = getContractFunctionsNameAndSighash(contractInterface, Object.keys(fake.functions));

// attach to every contract function, all the programmable and watchable logic
contractFunctions.forEach(([sighash, name]) => {
const { encoder, calls$, results$ } = getFunctionEventData(vm, contractInterface, fake.address, sighash);
const functionLogic = new SafeProgrammableContract(name, calls$, results$, encoder);
const { encoder, calls$ } = getFunctionEventData(vm, contractInterface, fake.address, sighash);
const functionLogic = new SafeProgrammableContract(contractInterface, sighash, name, calls$, encoder);
fillProgrammableContractFunction(fake[name], functionLogic);
addFunctionToMap(fake.address, sighash, functionLogic);
});

return fake;
Expand All @@ -36,7 +39,8 @@ export async function createFakeContract<Contract extends BaseContract>(
function mockifyContractFactory<T extends ContractFactory>(
vm: ObservableVM,
contractName: string,
factory: MockContractFactory<T>
factory: MockContractFactory<T>,
addFunctionToMap: (address: string, sighash: string | null, functionLogic: ProgrammableFunctionLogic) => void
): MockContractFactory<T> {
const realDeploy = factory.deploy;
factory.deploy = async (...args: Parameters<T['deploy']>) => {
Expand All @@ -45,9 +49,10 @@ function mockifyContractFactory<T extends ContractFactory>(

// attach to every contract function, all the programmable and watchable logic
contractFunctions.forEach(([sighash, name]) => {
const { encoder, calls$, results$ } = getFunctionEventData(vm, mock.interface, mock.address, sighash);
const functionLogic = new ProgrammableFunctionLogic(name, calls$, results$, encoder);
const { encoder, calls$ } = getFunctionEventData(vm, mock.interface, mock.address, sighash);
const functionLogic = new ProgrammableFunctionLogic(mock.interface, sighash, name, calls$, encoder);
fillProgrammableContractFunction(mock[name], functionLogic);
addFunctionToMap(mock.address, sighash, functionLogic);
});

// attach to every internal variable, all the editable logic
Expand All @@ -66,7 +71,7 @@ function mockifyContractFactory<T extends ContractFactory>(
const realConnect = factory.connect;
factory.connect = (...args: Parameters<T['connect']>): MockContractFactory<T> => {
const newFactory = realConnect.apply(factory, args) as MockContractFactory<T>;
return mockifyContractFactory(vm, contractName, newFactory);
return mockifyContractFactory(vm, contractName, newFactory, addFunctionToMap);
};

return factory;
Expand All @@ -75,10 +80,11 @@ function mockifyContractFactory<T extends ContractFactory>(
export async function createMockContractFactory<T extends ContractFactory>(
vm: ObservableVM,
contractName: string,
addFunctionToMap: (address: string, sighash: string | null, functionLogic: ProgrammableFunctionLogic) => void,
signerOrOptions?: ethers.Signer | FactoryOptions
): Promise<MockContractFactory<T>> {
const factory = (await hardhatEthers.getContractFactory(contractName, signerOrOptions)) as unknown as MockContractFactory<T>;
return mockifyContractFactory(vm, contractName, factory);
return mockifyContractFactory(vm, contractName, factory, addFunctionToMap);
}

async function initContract(
Expand All @@ -104,14 +110,8 @@ function getFunctionEventData(vm: ObservableVM, contractInterface: ethers.utils.
const encoder = getFunctionEncoder(contractInterface, sighash);
// Filter only the calls that correspond to this function, from vm beforeMessages
const calls$ = parseAndFilterBeforeMessages(vm.getBeforeMessages(), contractInterface, contractAddress, sighash);
// Get every result that comes right after a call to this function
const results$ = vm.getAfterMessages().pipe(
withLatestFrom(calls$),
distinct(([, call]) => call),
map(([answer]) => answer)
);

return { encoder, calls$, results$ };
return { encoder, calls$ };
}

function getFunctionEncoder(contractInterface: ethers.utils.Interface, sighash: string | null): (values?: ProgrammedReturnValue) => string {
Expand Down Expand Up @@ -157,7 +157,7 @@ function parseAndFilterBeforeMessages(
}),
// Ensure the message is directed to this contract
filter((message) => {
const target = message.delegatecall ? message.codeAddress : message.to;
const target = isDelegated(message) ? message.codeAddress : message.to;
return target?.toString().toLowerCase() === contractAddress.toLowerCase();
}),
map((message) => parseMessage(message, contractInterface, sighash)),
Expand Down Expand Up @@ -206,12 +206,28 @@ function parseMessage(message: Message, contractInterface: Interface, sighash: s
args: sighash === null ? toHexString(message.data) : getMessageArgs(message.data, contractInterface, sighash),
nonce: Sandbox.getNextNonce(),
value: BigNumber.from(message.value.toString()),
target: fromFancyAddress(message.delegatecall ? message.codeAddress : message.to!),
delegatedFrom: message.delegatecall ? fromFancyAddress(message.to!) : undefined,
target: targetAddres(message),
delegatedFrom: isDelegated(message) ? fromFancyAddress(message.to!) : undefined,
};
}

function getMessageArgs(messageData: Buffer, contractInterface: Interface, sighash: string): unknown[] {
function targetAddres(message: Message): string {
assert(message.to !== undefined, 'Message should have a target address');

if (message.codeAddress !== undefined && message.to! !== message.caller && message.caller !== message.codeAddress) {
return fromFancyAddress(message.codeAddress);
} else {
return fromFancyAddress(message.to!);
}
}

function isDelegated(message: Message): boolean {
assert(message.to !== undefined, 'Message should have a target address');

return message.codeAddress !== undefined && message.to !== message.caller && message.caller !== message.codeAddress;
}

export function getMessageArgs(messageData: Buffer, contractInterface: Interface, sighash: string): unknown[] {
try {
return contractInterface.decodeFunctionData(contractInterface.getFunction(sighash).format(), toHexString(messageData)) as unknown[];
} catch (err) {
Expand Down
44 changes: 21 additions & 23 deletions src/logic/programmable-function-logic.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { EVMResult } from '@nomicfoundation/ethereumjs-evm/dist/evm';
import { EvmError } from '@nomicfoundation/ethereumjs-evm/dist/exceptions';
import { Address } from '@nomicfoundation/ethereumjs-util';
import { ethers } from 'ethers';
import { Interface } from 'ethers/lib/utils';
import { findLast } from 'lodash';
import { Observable, withLatestFrom } from 'rxjs';
import { Observable } from 'rxjs';
import { getMessageArgs } from '../factories/smock-contract';
import { ContractCall, ProgrammedReturnValue, WhenCalledWithChain } from '../index';
import { WatchableFunctionLogic } from '../logic/watchable-function-logic';
import { fromHexString } from '../utils';
import { fromHexString, toHexString } from '../utils';

const EMPTY_ANSWER: Buffer = fromHexString('0x' + '00'.repeat(2048));

Expand All @@ -26,20 +27,15 @@ export class ProgrammableFunctionLogic extends WatchableFunctionLogic {
protected answerByArgs: { answer: ProgrammedAnswer; args: unknown[] }[] = [];

constructor(
private contractInterface: Interface,
private sighash: string | null,
name: string,
calls$: Observable<ContractCall>,
results$: Observable<EVMResult>,
encoder: (values?: ProgrammedReturnValue) => string
) {
super(name, calls$);

this.encoder = encoder;

// Intercept every result of this programmableFunctionLogic
results$.pipe(withLatestFrom(calls$)).subscribe(async ([result, call]) => {
// Modify it with the corresponding answer
await this.modifyAnswer(result, call);
});
}

returns(value?: ProgrammedReturnValue): void {
Expand Down Expand Up @@ -82,40 +78,42 @@ export class ProgrammableFunctionLogic extends WatchableFunctionLogic {
this.answerByArgs = [];
}

private async modifyAnswer(result: EVMResult, call: ContractCall): Promise<void> {
const answer = this.getCallAnswer(call);
async getEncodedCallAnswer(data: Buffer): Promise<[result: Buffer, shouldRevert: boolean] | undefined> {
this.callCount++;

const answer = this.getCallAnswer(data);
if (answer) {
result.execResult.gas = BigInt(0);
if (answer.shouldRevert) {
result.execResult.exceptionError = new EvmError('smock revert' as any);
result.execResult.returnValue = this.encodeRevertReason(answer.value);
} else {
result.execResult.exceptionError = undefined;
result.execResult.returnValue = await this.encodeValue(answer.value, call);
return [this.encodeRevertReason(answer.value), answer.shouldRevert];
}

return [await this.encodeValue(answer.value, data), answer.shouldRevert];
}
}

private getCallAnswer(call: ContractCall): ProgrammedAnswer | undefined {
private getCallAnswer(data: Buffer): ProgrammedAnswer | undefined {
const args = this.sighash === null ? toHexString(data) : getMessageArgs(data, this.contractInterface, this.sighash);

let answer: ProgrammedAnswer | undefined;

// if there is an answer for this call index, return it
answer = this.answerByIndex[this.getCallCount() - 1];
if (answer) return answer;

// if there is an answer for this call arguments, return it
answer = findLast(this.answerByArgs, (option) => this.isDeepEqual(option.args, call.args))?.answer;
answer = findLast(this.answerByArgs, (option) => this.isDeepEqual(option.args, args))?.answer;
if (answer) return answer;

// return the default answer
return this.defaultAnswer;
}

private async encodeValue(value: ProgrammedReturnValue, call: ContractCall): Promise<Buffer> {
private async encodeValue(value: ProgrammedReturnValue, data: Buffer): Promise<Buffer> {
if (value === undefined) return EMPTY_ANSWER;

let toEncode = typeof value === 'function' ? await value(call.args) : value;
const args = this.sighash === null ? toHexString(data) : getMessageArgs(data, this.contractInterface, this.sighash);

let toEncode = typeof value === 'function' ? await value(args) : value;

let encodedReturnValue: string = '0x';
try {
Expand Down
7 changes: 6 additions & 1 deletion src/logic/readable-storage-logic.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { stripZeros } from 'ethers/lib/utils';
import { SmockVMManager } from '../types';
import { fromHexString, remove0x, toFancyAddress, toHexString } from '../utils';
import {
Expand Down Expand Up @@ -31,7 +32,11 @@ export class ReadableStorageLogic {
slots.map(async (slotKeyPair) => ({
...slotKeyPair,
value: remove0x(
toHexString(await this.vmManager.getContractStorage(toFancyAddress(this.contractAddress), fromHexString(slotKeyPair.key)))
toHexString(
Buffer.from(
stripZeros(await this.vmManager.getContractStorage(toFancyAddress(this.contractAddress), fromHexString(slotKeyPair.key)))
)
)
),
}))
);
Expand Down
4 changes: 3 additions & 1 deletion src/logic/watchable-function-logic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { convertStructToPojo, getObjectAndStruct, humanizeTimes } from '../utils
export class WatchableFunctionLogic {
protected name: string;
protected callHistory: ContractCall[] = [];
protected callCount = 0;

constructor(name: string, calls$: Observable<ContractCall>) {
this.name = name;
Expand Down Expand Up @@ -107,7 +108,7 @@ export class WatchableFunctionLogic {
}

getCallCount(): number {
return this.callHistory.length;
return this.callCount;
}

getCalled(): boolean {
Expand All @@ -127,6 +128,7 @@ export class WatchableFunctionLogic {
}

protected reset() {
this.callCount = 0;
this.callHistory = [];
}

Expand Down
Loading

0 comments on commit 7e7d6f3

Please sign in to comment.