Skip to content

Commit

Permalink
Move boolToUint to SafeCast
Browse files Browse the repository at this point in the history
  • Loading branch information
ernestognw committed Feb 6, 2024
1 parent 8417107 commit 3ae1b4a
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 40 deletions.
47 changes: 19 additions & 28 deletions contracts/utils/math/Math.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pragma solidity ^0.8.20;

import {Address} from "../Address.sol";
import {Panic} from "../Panic.sol";
import {SafeCast} from "./SafeCast.sol";

/**
* @dev Standard math utilities missing in the Solidity language.
Expand Down Expand Up @@ -210,7 +211,7 @@ library Math {
* @dev Calculates x * y / denominator with full precision, following the selected rounding direction.
*/
function mulDiv(uint256 x, uint256 y, uint256 denominator, Rounding rounding) internal pure returns (uint256) {
return mulDiv(x, y, denominator) + boolToUint(unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0);
return mulDiv(x, y, denominator) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0);
}

/**
Expand Down Expand Up @@ -379,7 +380,7 @@ library Math {
function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) {
unchecked {
uint256 result = sqrt(a);
return result + boolToUint(unsignedRoundsUp(rounding) && result * result < a);
return result + SafeCast.toUint(unsignedRoundsUp(rounding) && result * result < a);
}
}

Expand All @@ -391,35 +392,35 @@ library Math {
uint256 result = 0;
uint256 exp;
unchecked {
exp = 128 * boolToUint(value > (1 << 128) - 1);
exp = 128 * SafeCast.toUint(value > (1 << 128) - 1);
value >>= exp;
result += exp;

exp = 64 * boolToUint(value > (1 << 64) - 1);
exp = 64 * SafeCast.toUint(value > (1 << 64) - 1);
value >>= exp;
result += exp;

exp = 32 * boolToUint(value > (1 << 32) - 1);
exp = 32 * SafeCast.toUint(value > (1 << 32) - 1);
value >>= exp;
result += exp;

exp = 16 * boolToUint(value > (1 << 16) - 1);
exp = 16 * SafeCast.toUint(value > (1 << 16) - 1);
value >>= exp;
result += exp;

exp = 8 * boolToUint(value > (1 << 8) - 1);
exp = 8 * SafeCast.toUint(value > (1 << 8) - 1);
value >>= exp;
result += exp;

exp = 4 * boolToUint(value > (1 << 4) - 1);
exp = 4 * SafeCast.toUint(value > (1 << 4) - 1);
value >>= exp;
result += exp;

exp = 2 * boolToUint(value > (1 << 2) - 1);
exp = 2 * SafeCast.toUint(value > (1 << 2) - 1);
value >>= exp;
result += exp;

result += boolToUint(value > 1);
result += SafeCast.toUint(value > 1);
}
return result;
}
Expand All @@ -431,7 +432,7 @@ library Math {
function log2(uint256 value, Rounding rounding) internal pure returns (uint256) {
unchecked {
uint256 result = log2(value);
return result + boolToUint(unsignedRoundsUp(rounding) && 1 << result < value);
return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 1 << result < value);
}
}

Expand Down Expand Up @@ -480,7 +481,7 @@ library Math {
function log10(uint256 value, Rounding rounding) internal pure returns (uint256) {
unchecked {
uint256 result = log10(value);
return result + boolToUint(unsignedRoundsUp(rounding) && 10 ** result < value);
return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 10 ** result < value);
}
}

Expand All @@ -494,23 +495,23 @@ library Math {
uint256 result = 0;
uint256 isGt;
unchecked {
isGt = boolToUint(value > (1 << 128) - 1);
isGt = SafeCast.toUint(value > (1 << 128) - 1);
value >>= isGt * 128;
result += isGt * 16;

isGt = boolToUint(value > (1 << 64) - 1);
isGt = SafeCast.toUint(value > (1 << 64) - 1);
value >>= isGt * 64;
result += isGt * 8;

isGt = boolToUint(value > (1 << 32) - 1);
isGt = SafeCast.toUint(value > (1 << 32) - 1);
value >>= isGt * 32;
result += isGt * 4;

isGt = boolToUint(value > (1 << 16) - 1);
isGt = SafeCast.toUint(value > (1 << 16) - 1);
value >>= isGt * 16;
result += isGt * 2;

result += boolToUint(value > (1 << 8) - 1);
result += SafeCast.toUint(value > (1 << 8) - 1);
}
return result;
}
Expand All @@ -522,7 +523,7 @@ library Math {
function log256(uint256 value, Rounding rounding) internal pure returns (uint256) {
unchecked {
uint256 result = log256(value);
return result + boolToUint(unsignedRoundsUp(rounding) && 1 << (result << 3) < value);
return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 1 << (result << 3) < value);
}
}

Expand All @@ -532,14 +533,4 @@ library Math {
function unsignedRoundsUp(Rounding rounding) internal pure returns (bool) {
return uint8(rounding) % 2 == 1;
}

/**
* @dev Cast a boolean (false or true) to a uint256 (0 or 1) with no jump.
*/
function boolToUint(bool b) internal pure returns (uint256 u) {
/// @solidity memory-safe-assembly
assembly {
u := iszero(iszero(b))
}
}
}
10 changes: 10 additions & 0 deletions contracts/utils/math/SafeCast.sol
Original file line number Diff line number Diff line change
Expand Up @@ -1150,4 +1150,14 @@ library SafeCast {
}
return int256(value);
}

/**
* @dev Cast a boolean (false or true) to a uint256 (0 or 1) with no jump.
*/
function toUint(bool b) internal pure returns (uint256 u) {
/// @solidity memory-safe-assembly
assembly {
u := iszero(iszero(b))
}
}
}
16 changes: 14 additions & 2 deletions scripts/generate/templates/SafeCast.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ const header = `\
pragma solidity ^0.8.20;
/**
* @dev Wrappers over Solidity's uintXX/intXX casting operators with added overflow
* @dev Wrappers over Solidity's uintXX/intXX/bool casting operators with added overflow
* checks.
*
* Downcasting from uint256/int256 in Solidity does not revert on overflow. This can
Expand Down Expand Up @@ -116,11 +116,23 @@ function toUint${length}(int${length} value) internal pure returns (uint${length
}
`;

const boolToUint = `
/**
* @dev Cast a boolean (false or true) to a uint256 (0 or 1) with no jump.
*/
function toUint(bool b) internal pure returns (uint256 u) {
/// @solidity memory-safe-assembly
assembly {
u := iszero(iszero(b))
}
}
`;

// GENERATE
module.exports = format(
header.trimEnd(),
'library SafeCast {',
errors,
[...LENGTHS.map(toUintDownCast), toUint(256), ...LENGTHS.map(toIntDownCast), toInt(256)],
[...LENGTHS.map(toUintDownCast), toUint(256), ...LENGTHS.map(toIntDownCast), toInt(256), boolToUint],
'}',
);
10 changes: 0 additions & 10 deletions test/utils/math/Math.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -512,14 +512,4 @@ describe('Math', function () {
});
});
});

describe('boolToUint', function () {
it('boolToUint(false) should be 0', async function () {
expect(await this.mock.$boolToUint(false)).to.equal(0n);
});

it('boolToUint(true) should be 1', async function () {
expect(await this.mock.$boolToUint(true)).to.equal(1n);
});
});
});
10 changes: 10 additions & 0 deletions test/utils/math/SafeCast.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,14 @@ describe('SafeCast', function () {
.withArgs(ethers.MaxUint256);
});
});

describe('toUint (bool)', function () {
it('toUint(false) should be 0', async function () {
expect(await this.mock.$toUint(false)).to.equal(0n);
});

it('toUint(true) should be 1', async function () {
expect(await this.mock.$toUint(true)).to.equal(1n);
});
});
});

0 comments on commit 3ae1b4a

Please sign in to comment.