From 3ae1b4a4910d7207013e419294a52ef8eeb9e9bd Mon Sep 17 00:00:00 2001 From: ernestognw Date: Tue, 6 Feb 2024 14:52:40 -0600 Subject: [PATCH] Move `boolToUint` to `SafeCast` --- contracts/utils/math/Math.sol | 47 +++++++++++--------------- contracts/utils/math/SafeCast.sol | 10 ++++++ scripts/generate/templates/SafeCast.js | 16 +++++++-- test/utils/math/Math.test.js | 10 ------ test/utils/math/SafeCast.test.js | 10 ++++++ 5 files changed, 53 insertions(+), 40 deletions(-) diff --git a/contracts/utils/math/Math.sol b/contracts/utils/math/Math.sol index 4e8e7b2c095..be05506ec53 100644 --- a/contracts/utils/math/Math.sol +++ b/contracts/utils/math/Math.sol @@ -5,6 +5,7 @@ pragma solidity ^0.8.20; import {Address} from "../Address.sol"; import {Panic} from "../Panic.sol"; +import {SafeCast} from "./SafeCast.sol"; /** * @dev Standard math utilities missing in the Solidity language. @@ -210,7 +211,7 @@ library Math { * @dev Calculates x * y / denominator with full precision, following the selected rounding direction. */ function mulDiv(uint256 x, uint256 y, uint256 denominator, Rounding rounding) internal pure returns (uint256) { - return mulDiv(x, y, denominator) + boolToUint(unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0); + return mulDiv(x, y, denominator) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0); } /** @@ -379,7 +380,7 @@ library Math { function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) { unchecked { uint256 result = sqrt(a); - return result + boolToUint(unsignedRoundsUp(rounding) && result * result < a); + return result + SafeCast.toUint(unsignedRoundsUp(rounding) && result * result < a); } } @@ -391,35 +392,35 @@ library Math { uint256 result = 0; uint256 exp; unchecked { - exp = 128 * boolToUint(value > (1 << 128) - 1); + exp = 128 * SafeCast.toUint(value > (1 << 128) - 1); value >>= exp; result += exp; - exp = 64 * boolToUint(value > (1 << 64) - 1); + exp = 64 * SafeCast.toUint(value > (1 << 64) - 1); value >>= exp; result += exp; - exp = 32 * boolToUint(value > (1 << 32) - 1); + exp = 32 * SafeCast.toUint(value > (1 << 32) - 1); value >>= exp; result += exp; - exp = 16 * boolToUint(value > (1 << 16) - 1); + exp = 16 * SafeCast.toUint(value > (1 << 16) - 1); value >>= exp; result += exp; - exp = 8 * boolToUint(value > (1 << 8) - 1); + exp = 8 * SafeCast.toUint(value > (1 << 8) - 1); value >>= exp; result += exp; - exp = 4 * boolToUint(value > (1 << 4) - 1); + exp = 4 * SafeCast.toUint(value > (1 << 4) - 1); value >>= exp; result += exp; - exp = 2 * boolToUint(value > (1 << 2) - 1); + exp = 2 * SafeCast.toUint(value > (1 << 2) - 1); value >>= exp; result += exp; - result += boolToUint(value > 1); + result += SafeCast.toUint(value > 1); } return result; } @@ -431,7 +432,7 @@ library Math { function log2(uint256 value, Rounding rounding) internal pure returns (uint256) { unchecked { uint256 result = log2(value); - return result + boolToUint(unsignedRoundsUp(rounding) && 1 << result < value); + return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 1 << result < value); } } @@ -480,7 +481,7 @@ library Math { function log10(uint256 value, Rounding rounding) internal pure returns (uint256) { unchecked { uint256 result = log10(value); - return result + boolToUint(unsignedRoundsUp(rounding) && 10 ** result < value); + return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 10 ** result < value); } } @@ -494,23 +495,23 @@ library Math { uint256 result = 0; uint256 isGt; unchecked { - isGt = boolToUint(value > (1 << 128) - 1); + isGt = SafeCast.toUint(value > (1 << 128) - 1); value >>= isGt * 128; result += isGt * 16; - isGt = boolToUint(value > (1 << 64) - 1); + isGt = SafeCast.toUint(value > (1 << 64) - 1); value >>= isGt * 64; result += isGt * 8; - isGt = boolToUint(value > (1 << 32) - 1); + isGt = SafeCast.toUint(value > (1 << 32) - 1); value >>= isGt * 32; result += isGt * 4; - isGt = boolToUint(value > (1 << 16) - 1); + isGt = SafeCast.toUint(value > (1 << 16) - 1); value >>= isGt * 16; result += isGt * 2; - result += boolToUint(value > (1 << 8) - 1); + result += SafeCast.toUint(value > (1 << 8) - 1); } return result; } @@ -522,7 +523,7 @@ library Math { function log256(uint256 value, Rounding rounding) internal pure returns (uint256) { unchecked { uint256 result = log256(value); - return result + boolToUint(unsignedRoundsUp(rounding) && 1 << (result << 3) < value); + return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 1 << (result << 3) < value); } } @@ -532,14 +533,4 @@ library Math { function unsignedRoundsUp(Rounding rounding) internal pure returns (bool) { return uint8(rounding) % 2 == 1; } - - /** - * @dev Cast a boolean (false or true) to a uint256 (0 or 1) with no jump. - */ - function boolToUint(bool b) internal pure returns (uint256 u) { - /// @solidity memory-safe-assembly - assembly { - u := iszero(iszero(b)) - } - } } diff --git a/contracts/utils/math/SafeCast.sol b/contracts/utils/math/SafeCast.sol index 0ed458b43c2..3063e80de69 100644 --- a/contracts/utils/math/SafeCast.sol +++ b/contracts/utils/math/SafeCast.sol @@ -1150,4 +1150,14 @@ library SafeCast { } return int256(value); } + + /** + * @dev Cast a boolean (false or true) to a uint256 (0 or 1) with no jump. + */ + function toUint(bool b) internal pure returns (uint256 u) { + /// @solidity memory-safe-assembly + assembly { + u := iszero(iszero(b)) + } + } } diff --git a/scripts/generate/templates/SafeCast.js b/scripts/generate/templates/SafeCast.js index f1954a7533f..a10ee75c975 100644 --- a/scripts/generate/templates/SafeCast.js +++ b/scripts/generate/templates/SafeCast.js @@ -7,7 +7,7 @@ const header = `\ pragma solidity ^0.8.20; /** - * @dev Wrappers over Solidity's uintXX/intXX casting operators with added overflow + * @dev Wrappers over Solidity's uintXX/intXX/bool casting operators with added overflow * checks. * * Downcasting from uint256/int256 in Solidity does not revert on overflow. This can @@ -116,11 +116,23 @@ function toUint${length}(int${length} value) internal pure returns (uint${length } `; +const boolToUint = ` + /** + * @dev Cast a boolean (false or true) to a uint256 (0 or 1) with no jump. + */ + function toUint(bool b) internal pure returns (uint256 u) { + /// @solidity memory-safe-assembly + assembly { + u := iszero(iszero(b)) + } + } +`; + // GENERATE module.exports = format( header.trimEnd(), 'library SafeCast {', errors, - [...LENGTHS.map(toUintDownCast), toUint(256), ...LENGTHS.map(toIntDownCast), toInt(256)], + [...LENGTHS.map(toUintDownCast), toUint(256), ...LENGTHS.map(toIntDownCast), toInt(256), boolToUint], '}', ); diff --git a/test/utils/math/Math.test.js b/test/utils/math/Math.test.js index bb565e8ae98..2762fcc5732 100644 --- a/test/utils/math/Math.test.js +++ b/test/utils/math/Math.test.js @@ -512,14 +512,4 @@ describe('Math', function () { }); }); }); - - describe('boolToUint', function () { - it('boolToUint(false) should be 0', async function () { - expect(await this.mock.$boolToUint(false)).to.equal(0n); - }); - - it('boolToUint(true) should be 1', async function () { - expect(await this.mock.$boolToUint(true)).to.equal(1n); - }); - }); }); diff --git a/test/utils/math/SafeCast.test.js b/test/utils/math/SafeCast.test.js index ecf55dc35a2..aa609faf0ac 100644 --- a/test/utils/math/SafeCast.test.js +++ b/test/utils/math/SafeCast.test.js @@ -146,4 +146,14 @@ describe('SafeCast', function () { .withArgs(ethers.MaxUint256); }); }); + + describe('toUint (bool)', function () { + it('toUint(false) should be 0', async function () { + expect(await this.mock.$toUint(false)).to.equal(0n); + }); + + it('toUint(true) should be 1', async function () { + expect(await this.mock.$toUint(true)).to.equal(1n); + }); + }); });