Skip to content

Commit

Permalink
[#103] Use updated registry interface
Browse files Browse the repository at this point in the history
  • Loading branch information
akshay-ap committed Sep 26, 2023
1 parent 5d77d40 commit 2e70453
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 33 deletions.
9 changes: 7 additions & 2 deletions contracts/SafeProtocolRegistry.sol
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,16 @@ contract SafeProtocolRegistry is ISafeProtocolRegistry, Ownable2Step {
* @return listedAt Timestamp of listing the module. This value will be 0 if not listed.
* @return flaggedAt Timestamp of falgging the module. This value will be 0 if not flagged.
*/
function check(address module) external view returns (uint64 listedAt, uint64 flaggedAt, uint8 moduleTypes) {
function check(address module, bytes32 data) external view returns (uint64 listedAt, uint64 flaggedAt) {
ModuleInfo memory moduleInfo = listedModules[module];
listedAt = moduleInfo.listedAt;
flaggedAt = moduleInfo.flaggedAt;
moduleTypes = moduleInfo.moduleTypes;

// If moduleType is not permitted, return 0 for listedAt and flaggedAt.
if (moduleInfo.moduleTypes & uint8(uint256(data)) == 0) {
listedAt = 0;
flaggedAt = 0;
}
}

/**
Expand Down
8 changes: 4 additions & 4 deletions contracts/base/RegistryManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ contract RegistryManager is Ownable2Step, OnlyAccountCallable {

event RegistryChanged(address indexed oldRegistry, address indexed newRegistry);

error ModuleNotPermitted(address plugin, uint64 listedAt, uint64 flaggedAt, uint8 moduleTypes);
error ModuleNotPermitted(address plugin, uint64 listedAt, uint64 flaggedAt, uint8 moduleType);
error ContractDoesNotImplementValidInterfaceId(address account);

modifier onlyPermittedModule(address module, uint8 moduleType) {
Expand Down Expand Up @@ -39,9 +39,9 @@ contract RegistryManager is Ownable2Step, OnlyAccountCallable {
*/
function checkPermittedModule(address module, uint8 moduleType) internal view {
// Only allow registered and non-flagged modules
(uint64 listedAt, uint64 flaggedAt, uint8 allowedModuleType) = ISafeProtocolRegistry(registry).check(module);
if (listedAt == 0 || flaggedAt != 0 || allowedModuleType & moduleType != moduleType) {
revert ModuleNotPermitted(module, listedAt, flaggedAt, allowedModuleType);
(uint64 listedAt, uint64 flaggedAt) = ISafeProtocolRegistry(registry).check(module, bytes32(uint256(moduleType)));
if (listedAt == 0 || flaggedAt != 0) {
revert ModuleNotPermitted(module, listedAt, flaggedAt, moduleType);
}
}

Expand Down
10 changes: 5 additions & 5 deletions contracts/interfaces/Registry.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ pragma solidity ^0.8.18;
import {IERC165} from "@openzeppelin/contracts/utils/introspection/IERC165.sol";

interface ISafeProtocolRegistry is IERC165 {
/** @param module Address of the module that should be checked
/**
* @notice This function allows external contracts to check if a module is listed and not flagged as faulty in the registry.
* @param module Address of the module that should be checked
* @param data bytes32 providing more information about the module. The type of this parameter is bytes32 to provide the flexibility to the developers to interpret the value in the registry. For example, it can be moduleType and registry would then check if given address can be used as that type of module.
* @return listedAt MUST return the block number when the module was listed in the registry (or 0 if not listed)
* @return flaggedAt MUST return the block number when the module was listed in the flagged as faulty (or 0 if not flagged)
* @return moduleTypes uint8 indicating the types of module that the contract can be used as in the protocol.
* The value is a bitwise OR of the module types. For example, if the module can be used as a plugin and
* a function handler, the value will be 2^0 (Plugin) + 2^1 (Function Handler) = 3.
*/
function check(address module) external view returns (uint64 listedAt, uint64 flaggedAt, uint8 moduleTypes);
function check(address module, bytes32 data) external view returns (uint64 listedAt, uint64 flaggedAt);
}
11 changes: 7 additions & 4 deletions test/FunctionHandlerManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ describe("FunctionHandler", async () => {

await expect(account.executeCallViaMock(account.target, 0, dataSetFunctionHandler, MaxUint256))
.to.be.revertedWithCustomError(functionHandlerManager, "ModuleNotPermitted")
.withArgs(user1.address, 0, 0, 0);
.withArgs(user1.address, 0, 0, 2);
});

it("Should not allow invalid module type as function handler", async () => {
it("Should not allow hooks module type as function handler", async () => {
const { functionHandlerManager, safeProtocolRegistry, account } = await setupTests();
const module = await getHooksWithPassingChecks();
await safeProtocolRegistry.connect(owner).addModule(module.target, MODULE_TYPE_HOOKS);
Expand All @@ -113,11 +113,14 @@ describe("FunctionHandler", async () => {
module.target,
]);

const moduleInfo = await safeProtocolRegistry.check(module.target);
const moduleInfo = await safeProtocolRegistry.check(
module.target,
hre.ethers.encodeBytes32String(MODULE_TYPE_FUNCTION_HANDLER.toString()),
);

await expect(account.executeCallViaMock(account.target, 0, dataSetFunctionHandler, MaxUint256))
.to.be.revertedWithCustomError(functionHandlerManager, "ModuleNotPermitted")
.withArgs(module.target, moduleInfo.listedAt, 0, MODULE_TYPE_HOOKS);
.withArgs(module.target, moduleInfo.listedAt, 0, MODULE_TYPE_FUNCTION_HANDLER);
});

it("Should revert with FunctionHandlerNotSet when function handler is not enabled", async () => {
Expand Down
2 changes: 1 addition & 1 deletion test/SafeProtocolManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ describe("SafeProtocolManager", async () => {
]);
await expect(account.exec(await safeProtocolManager.getAddress(), 0, data))
.to.be.revertedWithCustomError(safeProtocolManager, "ModuleNotPermitted")
.withArgs(pluginAddress, 0, 0, 0);
.withArgs(pluginAddress, 0, 0, MODULE_TYPE_PLUGIN);
});

it("Should not allow an Account to enable plugin that does not support ERC165", async () => {
Expand Down
48 changes: 31 additions & 17 deletions test/SafeProtocolRegistry.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ describe("SafeProtocolRegistry", async () => {
return { safeProtocolRegistry, mockFunctionHandlerAddress };
});

// A helper function to convert a number to a bytes32 value
const numberToBytes32 = (value: bigint) => hre.ethers.zeroPadValue(hre.ethers.toBeHex(value), 32);

it("Should allow adding a module only once", async () => {
const { safeProtocolRegistry } = await setupTests();
const mockHookAddress = (await getHooksWithPassingChecks()).target;
Expand All @@ -34,24 +37,23 @@ describe("SafeProtocolRegistry", async () => {
const mockModule = await (await hre.ethers.getContractFactory("MockContract")).deploy();
await mockModule.givenMethodReturnBool("0x01ffc9a7", true);

await safeProtocolRegistry.connect(owner).addModule(mockModule, MODULE_TYPE_PLUGIN + MODULE_TYPE_FUNCTION_HANDLER);

const [listedAt, flaggedAt, moduleTypes] = await safeProtocolRegistry.check.staticCall(mockModule.target);
expect(listedAt).to.be.greaterThan(0);
expect(flaggedAt).to.be.equal(0);
expect(moduleTypes).to.be.equal(MODULE_TYPE_PLUGIN + MODULE_TYPE_FUNCTION_HANDLER);

const mockModule2 = await (await hre.ethers.getContractFactory("MockContract")).deploy();
await mockModule2.givenMethodReturnBool("0x01ffc9a7", true);

await safeProtocolRegistry
.connect(owner)
.addModule(mockModule2, MODULE_TYPE_PLUGIN + MODULE_TYPE_FUNCTION_HANDLER + MODULE_TYPE_HOOKS);
.addModule(mockModule, MODULE_TYPE_PLUGIN + MODULE_TYPE_FUNCTION_HANDLER + MODULE_TYPE_HOOKS);
const [listedAt, flaggedAt] = await safeProtocolRegistry.check.staticCall(
mockModule.target,
numberToBytes32(MODULE_TYPE_FUNCTION_HANDLER),
);
expect(listedAt).to.be.greaterThan(0);
expect(flaggedAt).to.be.equal(0);

const [listedAt2, flaggedAt2, moduleTypes2] = await safeProtocolRegistry.check.staticCall(mockModule2.target);
const [listedAt2, flaggedAt2] = await safeProtocolRegistry.check.staticCall(mockModule.target, numberToBytes32(MODULE_TYPE_PLUGIN));
expect(listedAt2).to.be.greaterThan(0);
expect(flaggedAt2).to.be.equal(0);
expect(moduleTypes2).to.be.equal(MODULE_TYPE_PLUGIN + MODULE_TYPE_FUNCTION_HANDLER + MODULE_TYPE_HOOKS);

const [listedAt3, flaggedAt3] = await safeProtocolRegistry.check.staticCall(mockModule.target, numberToBytes32(MODULE_TYPE_HOOKS));
expect(listedAt3).to.be.greaterThan(0);
expect(flaggedAt3).to.be.equal(0);
});

it("Should not allow adding a module with invalid moduleTypes", async () => {
Expand Down Expand Up @@ -89,7 +91,8 @@ describe("SafeProtocolRegistry", async () => {

expect(await safeProtocolRegistry.connect(owner).flagModule(mockHookAddress));

const [flaggedAt] = await safeProtocolRegistry.check.staticCall(mockHookAddress);
const [listedAt, flaggedAt] = await safeProtocolRegistry.check.staticCall(mockHookAddress, numberToBytes32(MODULE_TYPE_HOOKS));
expect(listedAt).to.be.gt(0);
expect(flaggedAt).to.be.gt(0);
});

Expand All @@ -111,16 +114,27 @@ describe("SafeProtocolRegistry", async () => {

it("Should return (0,0,0) for non-listed module", async () => {
const { safeProtocolRegistry } = await setupTests();
const [listedAt, flaggedAt, moduleTypes] = await safeProtocolRegistry.check.staticCall(AddressZero);

const [listedAt, flaggedAt] = await safeProtocolRegistry.check.staticCall(AddressZero, numberToBytes32(MODULE_TYPE_PLUGIN));
expect(listedAt).to.be.equal(0);
expect(flaggedAt).to.be.equal(0);
expect(moduleTypes).to.be.equal(0);

const [listedAt2, flaggedAt2] = await safeProtocolRegistry.check.staticCall(
AddressZero,
numberToBytes32(MODULE_TYPE_FUNCTION_HANDLER),
);
expect(listedAt2).to.be.equal(0);
expect(flaggedAt2).to.be.equal(0);

const [listedAt3, flaggedAt3] = await safeProtocolRegistry.check.staticCall(AddressZero, numberToBytes32(MODULE_TYPE_HOOKS));
expect(listedAt3).to.be.equal(0);
expect(flaggedAt3).to.be.equal(0);
});

it("Should return true when valid interfaceId is passed", async () => {
const { safeProtocolRegistry } = await setupTests();
expect(await safeProtocolRegistry.supportsInterface.staticCall("0x01ffc9a7")).to.be.true;
expect(await safeProtocolRegistry.supportsInterface.staticCall("0xc23697a8")).to.be.true;
expect(await safeProtocolRegistry.supportsInterface.staticCall("0x253bd7b7")).to.be.true;
});

it("Should return false when invalid interfaceId is passed", async () => {
Expand Down

0 comments on commit 2e70453

Please sign in to comment.