Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sqrt-via-dlog: cleanup and 20% accel for vartime Bandersnatch/Banderwagon deserialization #362

Merged
merged 2 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions benchmarks/bench_fields_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,33 @@ proc sqrtRatioBench*(T: typedesc, iters: int) =
bench("Fused SquareRoot+Division+isSquare sqrt(u/v)", T, iters):
let isSquare = r.sqrt_ratio_if_square(u, v)

proc sqrtVartimeBench*(T: typedesc, iters: int) =
  ## Benchmark ``sqrt_if_square_vartime`` on one random element of ``T``.
  ##
  ## The benchmark label is built at compile time from two properties of
  ## the curve ``T.C``: the shape of its prime modulus and whether an
  ## addition chain is available for the square root.
  const modulusKind = block:
    when T.C.has_P_3mod4_primeModulus():
      "p ≡ 3 (mod 4)"
    elif T.C.has_P_5mod8_primeModulus():
      "p ≡ 5 (mod 8)"
    else:
      "Tonelli-Shanks"
  const chainKind = block:
    when T.C.hasSqrtAddchain() or T.C.hasTonelliShanksAddchain():
      "with addition chain"
    else:
      "without addition chain"
  const label = "Square Root (vartime " & modulusKind & " " & chainKind & ")"

  let input = rng.random_unsafe(T)

  bench(label, T, iters):
    # The square root works in place, so each run starts from a fresh copy.
    var r = input
    discard r.sqrt_if_square_vartime()

proc sqrtRatioVartimeBench*(T: typedesc, iters: int) =
  ## Benchmark ``sqrt_ratio_if_square_vartime`` on a random numerator ``u``
  ## and a random denominator ``v`` drawn from ``T``.
  let
    u = rng.random_unsafe(T)
    v = rng.random_unsafe(T)
  var r: T
  bench("Fused SquareRoot+Division+isSquare sqrt_vartime(u/v)", T, iters):
    # Only the timing matters here; the "is square" flag is dropped.
    discard r.sqrt_ratio_if_square_vartime(u, v)

proc powBench*(T: typedesc, iters: int) =
let x = rng.random_unsafe(T)
let exponent = rng.random_unsafe(BigInt[T.C.getCurveOrderBitwidth()])
Expand Down
3 changes: 3 additions & 0 deletions benchmarks/bench_fp.nim
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ proc main() =
isSquareBench(Fp[curve], ExponentIters)
sqrtBench(Fp[curve], ExponentIters)
sqrtRatioBench(Fp[curve], ExponentIters)
when curve == Bandersnatch:
sqrtVartimeBench(Fp[curve], ExponentIters)
sqrtRatioVartimeBench(Fp[curve], ExponentIters)
# Exponentiation by a "secret" of size ~the curve order
powBench(Fp[curve], ExponentIters)
powUnsafeBench(Fp[curve], ExponentIters)
Expand Down
56 changes: 45 additions & 11 deletions constantine/math/arithmetic/finite_fields_square_root.nim
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ func invsqrt_vartime*[C](r: var Fp[C], a: Fp[C]) =
## The square root, if it exists, is multivalued,
## i.e. both x² == (-x)²
## This procedure returns a deterministic result
##
## This procedure is NOT constant-time
when C.has_P_3mod4_primeModulus():
r.invsqrt_p3mod4(a)
Expand All @@ -250,9 +251,13 @@ func sqrt*[C](a: var Fp[C]) =
a *= t

func sqrt_vartime*[C](a: var Fp[C]) =
  ## Replace ``a`` by its square root, in place.
  ##
  ## ``a`` must be a square; the result is undefined otherwise.
  ##
  ## The root is formed as ``a * (1/√a)``, with the inverse square root
  ## supplied by ``invsqrt_vartime``.
  ##
  ## This is NOT constant-time
  var invRoot {.noInit.}: Fp[C]
  invRoot.invsqrt_vartime(a)   # invRoot <- 1/√a
  a *= invRoot                 # a       <- a * 1/√a
Expand All @@ -271,8 +276,13 @@ func sqrt_invsqrt*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]) =
sqrt.prod(invsqrt, a)

func sqrt_invsqrt_vartime*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]) =
  ## Compute both the square root of ``a`` (into ``sqrt``) and the
  ## inverse square root of ``a`` (into ``invsqrt``).
  ##
  ## ``a`` must be a square; the results are undefined otherwise.
  ##
  ## This is NOT constant-time
  invsqrt_vartime(invsqrt, a)   # invsqrt <- 1/√a
  prod(sqrt, invsqrt, a)        # sqrt    <- a * 1/√a

Expand All @@ -292,8 +302,13 @@ func sqrt_invsqrt_if_square*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]): SecretBool
result = test == a

func sqrt_invsqrt_if_square_vartime*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]): SecretBool =
## It is not constant-time
## This has the precomp optimisation
## Compute the square root and inverse square root of ``a``
##
## This returns true if ``a`` is square and sqrt/invsqrt contains the square root/inverse square root
##
## The result is undefined otherwise
##
## This is NOT constant-time
sqrt_invsqrt_vartime(sqrt, invsqrt, a)
var test {.noInit.}: Fp[C]
test.square(sqrt)
Expand All @@ -311,6 +326,19 @@ func sqrt_if_square*[C](a: var Fp[C]): SecretBool =
result = sqrt_invsqrt_if_square(sqrt, invsqrt, a)
a = sqrt

func sqrt_if_square_vartime*[C](a: var Fp[C]): SecretBool =
  ## Try to replace ``a`` by its square root, in place.
  ##
  ## The returned flag is the one reported by
  ## ``sqrt_invsqrt_if_square_vartime``. When ``a`` is a square it is
  ## overwritten with a square root; when it is not, the content of ``a``
  ## afterwards is undefined.
  ##
  ## A square root, if it exists, is multivalued (x² == (-x)²);
  ## this procedure returns a deterministic result.
  ##
  ## This is NOT constant-time
  var root {.noInit.}, invRoot {.noInit.}: Fp[C]
  result = sqrt_invsqrt_if_square_vartime(root, invRoot, a)
  # Unconditional write-back: on a non-square ``a`` now holds garbage.
  a = root

func invsqrt_if_square*[C](r: var Fp[C], a: Fp[C]): SecretBool =
## If ``a`` is a square, compute the inverse square root of ``a``
## if not, ``r`` is undefined.
Expand All @@ -323,8 +351,10 @@ func invsqrt_if_square*[C](r: var Fp[C], a: Fp[C]): SecretBool =
result = sqrt_invsqrt_if_square(sqrt, r, a)

func invsqrt_if_square_vartime*[C](r: var Fp[C], a: Fp[C]): SecretBool =
  ## If ``a`` is a square, store the inverse square root of ``a`` in ``r``;
  ## if not, ``r`` is undefined.
  ##
  ## The returned flag is the one reported by
  ## ``sqrt_invsqrt_if_square_vartime``.
  ##
  ## This procedure is NOT constant-time
  var root {.noInit.}: Fp[C]   # square root by-product, not returned
  return sqrt_invsqrt_if_square_vartime(root, r, a)

Expand Down Expand Up @@ -365,8 +395,12 @@ func sqrt_ratio_if_square*(r: var Fp, u, v: Fp): SecretBool {.inline.} =
r *= u # √u/√v

func sqrt_ratio_if_square_vartime*(r: var Fp, u, v: Fp): SecretBool {.inline.} =
## It is not constant-time
## This has the precomp optimisation
## If u/v is a square, compute √(u/v)
## if not, the result is undefined
##
## r must not alias u or v
##
## This is NOT constant-time
var uv{.noInit.}: Fp
uv.prod(u, v) # uv
result = r.invsqrt_if_square_vartime(uv) # 1/√uv
Expand Down
140 changes: 23 additions & 117 deletions constantine/math/arithmetic/finite_fields_square_root_precomp.nim
Original file line number Diff line number Diff line change
Expand Up @@ -31,133 +31,33 @@ func sqrtAlg_NegDlogInSmallDyadicSubgroup_vartime(x: Fp): int {.tags:[VarTime],
let key = cast[int](x.mres.limbs[0] and SecretWord 0xFFFF)
return Fp.C.sqrtDlog(dlogLUT).getOrDefault(key, 0)


# sqrtAlg_GetPrecomputedRootOfUnity sets target to g^(multiplier << (order * sqrtParam_BlockSize)), where g is the fixed primitive 2^32th root of unity.
#
# We assume that 0 <= order*sqrtParam_BlockSize <= 32 and that multiplier is in [0, 1 << sqrtParam_BlockSize)
func sqrtAlg_GetPrecomputedRootOfUnity(target: var Fp, multiplier: int, order: uint) =
## Copy the entry ``[order][multiplier]`` of the curve's ``PrecomputedBlocks``
## sqrt-dlog table into ``target``.
## No bounds are checked here; the caller is responsible for valid indices.
target = Fp.C.sqrtDlog(PrecomputedBlocks)[order][multiplier]


func sqrtAlg_ComputeRelevantPowers(z: Fp, squareRootCandidate: var Fp, rootOfUnity: var Fp) {.addchain.} =
## Raise ``z`` to a fixed exponent with a hard-coded addition chain and derive
## two outputs from the accumulator ``acc``:
##   - ``squareRootCandidate`` = acc * z
##   - ``rootOfUnity``         = acc² * z
##
## sliding window-type algorithm with window-size 5
## Note that we precompute and use z^255 multiple times (even though it's not size 5)
## and some windows actually overlap
var z2, z3, z7, z6, z9, z11, z13, z19, z21, z25, z27, z29, z31, z255 {.noInit.} : Fp
var acc: Fp
# Small-power table: z^2, z^3, z^6, z^7, then odd powers up to z^31.
z2.square(z)
z3.prod(z2, z)
z6.prod(z3, z3)
z7.prod(z6, z)
z9.prod(z7, z2)
z11.prod(z9, z2)
z13.prod(z11, z2)
z19.prod(z13, z6)
z21.prod(z19, z2)
z25.prod(z19, z6)
z27.prod(z25, z2)
z29.prod(z27, z2)
z31.prod(z29, z2)
# acc = z^56, then two squarings give z^224; z^224 * z^31 = z^255.
acc.prod(z27, z29)
acc.prod(acc, acc)
acc.prod(acc, acc)
z255.prod(acc, z31)
# Main chain: repeated squarings interleaved with table multiplications.
acc.prod(acc, acc)
acc.prod(acc, acc)
acc.prod(acc, z31)
acc.square_repeated(6)
acc.prod(acc, z27)
acc.square_repeated(6)
acc.prod(acc, z19)
acc.square_repeated(5)
acc.prod(acc, z21)
acc.square_repeated(7)
acc.prod(acc, z25)
acc.square_repeated(6)
acc.prod(acc, z19)
acc.square_repeated(5)
acc.prod(acc, z7)
acc.square_repeated(5)
acc.prod(acc, z11)
acc.square_repeated(5)
acc.prod(acc, z29)
acc.square_repeated(5)
acc.prod(acc, z9)
acc.square_repeated(7)
acc.prod(acc, z3)
acc.square_repeated(7)
acc.prod(acc, z25)
acc.square_repeated(5)
acc.prod(acc, z25)
acc.square_repeated(5)
acc.prod(acc, z27)
acc.square_repeated(8)
acc.prod(acc, z)
acc.square_repeated(8)
acc.prod(acc, z)
acc.square_repeated(6)
acc.prod(acc, z13)
acc.square_repeated(7)
acc.prod(acc, z7)
acc.square_repeated(3)
acc.prod(acc, z3)
acc.square_repeated(13)
acc.prod(acc, z21)
acc.square_repeated(5)
acc.prod(acc, z9)
acc.square_repeated(5)
acc.prod(acc, z27)
acc.square_repeated(5)
acc.prod(acc, z27)
acc.square_repeated(5)
acc.prod(acc, z9)
acc.square_repeated(10)
acc.prod(acc, z)
acc.square_repeated(7)
acc.prod(acc, z255)
acc.square_repeated(8)
acc.prod(acc, z255)
acc.square_repeated(6)
acc.prod(acc, z11)
acc.square_repeated(9)
acc.prod(acc, z255)
acc.square_repeated(2)
acc.prod(acc, z)
acc.square_repeated(7)
acc.prod(acc, z255)
acc.square_repeated(8)
acc.prod(acc, z255)
acc.square_repeated(8)
acc.prod(acc, z255)
acc.square_repeated(8)
acc.prod(acc, z255)
# acc is now z^((BaseFieldMultiplicativeOddOrder - 1)/2)
# rootOfUnity = acc² * z, i.e. z^BaseFieldMultiplicativeOddOrder given the line above.
rootOfUnity.square(acc)
rootOfUnity *= z
# squareRootCandidate = acc * z, i.e. z^((BaseFieldMultiplicativeOddOrder + 1)/2).
squareRootCandidate.prod(acc, z)


func invSqrtEqDyadic_vartime*(z: var Fp) =
## The algorithm works by essentially computing the dlog of z and then halving it.
## negExponent is intended to hold the negative of the dlog of z.
func invSqrtEqDyadic_vartime*(a: var Fp) =
## The algorithm works by essentially computing the dlog of a and then halving it.
## negExponent is intended to hold the negative of the dlog of a.
## We determine this 32-bit value (usually) _sqrtBlockSize many bits at a time, starting with the least-significant bits.
##
## If _sqrtBlockSize does not divide 32, the *first* iteration will determine fewer bits.

var negExponent: int
var temp, temp2: Fp

# set powers[i] to z^(1<< (i*blocksize))
# set powers[i] to a^(1<< (i*blocksize))
var powers: array[4, Fp]
powers[0] = z
powers[0] = a
for i in 1 ..< Fp.C.sqrtDlog(Blocks):
powers[i] = powers[i - 1]
for j in 0 ..< Fp.C.sqrtDlog(BlockSize):
powers[i].square(powers[i])

## looking at the dlogs, powers[i] is essentially the wanted exponent, left-shifted by i*_sqrtBlockSize and taken mod 1<<32
## dlogHighDyadicRootNeg essentially (up to sign) reads off the _sqrtBlockSize many most significant bits. (returned as low-order bits)
##
##
## first iteration may be slightly special if BlockSize does not divide 32
negExponent = sqrtAlg_NegDlogInSmallDyadicSubgroup_vartime(powers[Fp.C.sqrtDlog(Blocks) - 1])
negExponent = negExponent shr Fp.C.sqrtDlog(FirstBlockUnusedBits)
Expand All @@ -173,21 +73,27 @@ func invSqrtEqDyadic_vartime*(z: var Fp) =
for j in 0 ..< i:
sqrtAlg_GetPrecomputedRootOfUnity(temp, int( (negExponent shr (j*Fp.C.sqrtDlog(BlockSize))) and Fp.C.sqrtDlog(BitMask) ), uint(j + Fp.C.sqrtDlog(Blocks) - 1 - i))
temp2.prod(temp2, temp)

var newBits = sqrtAlg_NegDlogInSmallDyadicSubgroup_vartime(temp2)
negExponent = negExponent or (newBits shl ((i*Fp.C.sqrtDlog(BlockSize)) - Fp.C.sqrtDlog(FirstBlockUnusedBits)))

negExponent = negExponent shr 1
z.setOne()
a.setOne()

for i in 0 ..< Fp.C.sqrtDlog(Blocks):
sqrtAlg_GetPrecomputedRootOfUnity(temp, int((negExponent shr (i*Fp.C.sqrtDlog(BlockSize))) and Fp.C.sqrtDlog(BitMask)), uint(i))
z.prod(z, temp)

func inv_sqrt_precomp_vartime*(dst: var Fp, x: Fp) {.inline.} =
## Write into ``dst`` the field inverse of ``candidate * rootOfUnity``, where
## both factors are derived from ``x`` by ``sqrtAlg_ComputeRelevantPowers`` and
## ``rootOfUnity`` is then transformed in place by ``invSqrtEqDyadic_vartime``.
## NOTE(review): presumably this is 1/√x when ``x`` is a square — confirm.
dst.setOne()
var candidate, rootOfUnity {.noInit.}: Fp
sqrtAlg_ComputeRelevantPowers(x, candidate, rootOfUnity)
invSqrtEqDyadic_vartime(rootOfUnity)
dst.prod(candidate, rootOfUnity)
# Final field inversion of the product computed above.
dst.inv()
a.prod(a, temp)

func inv_sqrt_precomp_vartime*(r: var Fp, a: Fp) =
  ## Write into ``r`` the product of two factors derived from ``a``:
  ## an addition-chain power of ``a`` and the output of
  ## ``invSqrtEqDyadic_vartime`` applied to ``oddPower² * a``.
  ## NOTE(review): presumably this is 1/√a when ``a`` is a square — confirm.
  ##
  ## Notation: the field modulus is q == s * 2^e + 1, where e is the
  ## 2-adicity of the field (2^e is the largest power of two dividing q-1)
  ## and s, e are precomputed values.
  var oddPower, dyadicPart {.noInit.}: Fp

  # oddPower = a^((q-1-2^e)/(2*2^e)), i.e. a^((s-1)/2)
  oddPower.precompute_tonelli_shanks_addchain(a)

  # dyadicPart = oddPower² * a, i.e. a^s
  dyadicPart.square(oddPower)
  dyadicPart *= a

  # Transformed in place by the dlog-table routine.
  invSqrtEqDyadic_vartime(dyadicPart)

  r.prod(oddPower, dyadicPart)
8 changes: 4 additions & 4 deletions constantine/math/constants/bandersnatch_sqrt.nim
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,10 @@ func precompute_tonelli_shanks_addchain*(
r *= a


# ############################################################
#
# Optimized square-root via Discrete Log lookup tables
#
# ############################################################
#
# Optimized square-root via Discrete Log lookup tables
#
# ############################################################

const
Expand Down
Loading