diff --git a/contracts/safeguard/sentinel/Guard.sol b/contracts/safeguard/sentinel/Guard.sol index 6c03b46c0..700c2c0ca 100644 --- a/contracts/safeguard/sentinel/Guard.sol +++ b/contracts/safeguard/sentinel/Guard.sol @@ -47,15 +47,18 @@ abstract contract Guard is Ownable { emit GuardUpdated(msg.sender, GuardState.Relaxed); } - function addGuards(address[] calldata _accounts, uint256 _thresholdIncrement) external onlyOwner { - for (uint256 i = 0; i < _accounts.length; i++) { - _addGuard(_accounts[i]); + function updateGuards( + address[] calldata _add, + address[] calldata _remove, + uint256 _newRelaxThreshold + ) external onlyOwner { + for (uint256 i = 0; i < _add.length; i++) { + _addGuard(_add[i]); } - if (_thresholdIncrement > 0) { - require(_thresholdIncrement <= _accounts.length, "invalid threshold increment"); - _setRelaxThreshold(relaxThreshold + _thresholdIncrement); - _updateRelaxed(); + for (uint256 i = 0; i < _remove.length; i++) { + _removeGuard(_remove[i]); } + _setRelaxThreshold(_newRelaxThreshold); } function _addGuard(address _account) private { @@ -65,19 +68,6 @@ abstract contract Guard is Ownable { emit GuardUpdated(_account, GuardState.Guarded); } - function removeGuards(address[] calldata _accounts, uint256 _thresholdDecrement) external onlyOwner { - for (uint256 i = 0; i < _accounts.length; i++) { - _removeGuard(_accounts[i]); - } - if (_thresholdDecrement > 0) { - require(_thresholdDecrement <= _accounts.length, "invalid threshold decrement"); - _setRelaxThreshold(relaxThreshold - _thresholdDecrement); - } else if (relaxThreshold > guards.length) { - _setRelaxThreshold(guards.length); - } - _updateRelaxed(); - } - function _removeGuard(address _account) private { GuardState state = guardStates[_account]; require(state != GuardState.None, "account is not guard"); @@ -101,12 +91,12 @@ abstract contract Guard is Ownable { function setRelaxThreshold(uint256 _threshold) external onlyOwner { _setRelaxThreshold(_threshold); - _updateRelaxed(); } function _setRelaxThreshold(uint256 _threshold) private { require(_threshold <= guards.length, "invalid threshold"); relaxThreshold = _threshold; + _updateRelaxed(); emit RelaxThresholdUpdated(_threshold, guards.length); } diff --git a/test/Sentinel.spec.ts b/test/Sentinel.spec.ts index add697387..e71590675 100644 --- a/test/Sentinel.spec.ts +++ b/test/Sentinel.spec.ts @@ -26,7 +26,7 @@ describe('Sentinel Tests', function () { guards = [accounts[0], accounts[1]]; pausers = [accounts[2], accounts[3]]; governor = accounts[4]; - await sentinel.addGuards([guards[0].address, guards[1].address], 2); + await sentinel.updateGuards([guards[0].address, guards[1].address], [], 2); await sentinel.addPausers([pausers[0].address, pausers[1].address], [1, 2]); await sentinel.addGovernors([governor.address]); await bridge.addPauser(sentinel.address); @@ -35,7 +35,7 @@ describe('Sentinel Tests', function () { }); it('should pass guard tests', async function () { - await expect(sentinel.addGuards([pausers[0].address], 0)) + await expect(sentinel.updateGuards([pausers[0].address], [], 2)) .to.emit(sentinel, 'GuardUpdated') .withArgs(pausers[0].address, 1); @@ -46,7 +46,7 @@ describe('Sentinel Tests', function () { await sentinel.connect(guards[0]).relax(); await expect(sentinel.connect(guards[1]).relax()).to.emit(sentinel, 'RelaxStatusUpdated').withArgs(true); - await expect(sentinel.addGuards([pausers[1].address], 1)) + await expect(sentinel.updateGuards([pausers[1].address], [], 3)) .to.emit(sentinel, 'RelaxStatusUpdated') .withArgs(false); @@ -54,7 +54,7 @@ describe('Sentinel Tests', function () { expect(await sentinel.relaxThreshold()).to.equal(3); expect(await sentinel.numRelaxedGuards()).to.equal(2); - await expect(sentinel.removeGuards([guards[0].address, pausers[0].address], 1)) + await expect(sentinel.updateGuards([], [guards[0].address, pausers[0].address], 2)) .to.emit(sentinel, 'GuardUpdated') .withArgs(guards[0].address, 0) .to.emit(sentinel, 'GuardUpdated') @@ -63,7 +63,7 @@ describe('Sentinel Tests', function () { expect(await sentinel.relaxThreshold()).to.equal(2); expect(await sentinel.numRelaxedGuards()).to.equal(1); - await expect(sentinel.removeGuards([pausers[1].address], 1)) + await expect(sentinel.updateGuards([], [pausers[1].address], 1)) .to.emit(sentinel, 'RelaxStatusUpdated') .withArgs(true); expect(await sentinel.numGuards()).to.equal(1);