Skip to content

Commit

Permalink
sqrt-via-dlog: cleanup and 20% accel for vartime Bandersnatch/Banderw…
Browse files Browse the repository at this point in the history
…agon deserialization (#362)

* sqrt-via-dlog: cleanup and 20% accel for Verkle Trees

* sqrt-vartime: Benches + comment improvement
  • Loading branch information
mratsim authored Feb 17, 2024
1 parent 6978344 commit 008fe10
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 132 deletions.
27 changes: 27 additions & 0 deletions benchmarks/bench_fields_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,33 @@ proc sqrtRatioBench*(T: typedesc, iters: int) =
bench("Fused SquareRoot+Division+isSquare sqrt(u/v)", T, iters):
let isSquare = r.sqrt_ratio_if_square(u, v)

proc sqrtVartimeBench*(T: typedesc, iters: int) =
let x = rng.random_unsafe(T)

const algoType = block:
when T.C.has_P_3mod4_primeModulus():
"p ≡ 3 (mod 4)"
elif T.C.has_P_5mod8_primeModulus():
"p ≡ 5 (mod 8)"
else:
"Tonelli-Shanks"
const addchain = block:
when T.C.hasSqrtAddchain() or T.C.hasTonelliShanksAddchain():
"with addition chain"
else:
"without addition chain"
const desc = "Square Root (vartime " & algoType & " " & addchain & ")"
bench(desc, T, iters):
var r = x
discard r.sqrt_if_square_vartime()

proc sqrtRatioVartimeBench*(T: typedesc, iters: int) =
var r: T
let u = rng.random_unsafe(T)
let v = rng.random_unsafe(T)
bench("Fused SquareRoot+Division+isSquare sqrt_vartime(u/v)", T, iters):
let isSquare = r.sqrt_ratio_if_square_vartime(u, v)

proc powBench*(T: typedesc, iters: int) =
let x = rng.random_unsafe(T)
let exponent = rng.random_unsafe(BigInt[T.C.getCurveOrderBitwidth()])
Expand Down
3 changes: 3 additions & 0 deletions benchmarks/bench_fp.nim
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ proc main() =
isSquareBench(Fp[curve], ExponentIters)
sqrtBench(Fp[curve], ExponentIters)
sqrtRatioBench(Fp[curve], ExponentIters)
when curve == Bandersnatch:
sqrtVartimeBench(Fp[curve], ExponentIters)
sqrtRatioVartimeBench(Fp[curve], ExponentIters)
# Exponentiation by a "secret" of size ~the curve order
powBench(Fp[curve], ExponentIters)
powUnsafeBench(Fp[curve], ExponentIters)
Expand Down
56 changes: 45 additions & 11 deletions constantine/math/arithmetic/finite_fields_square_root.nim
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ func invsqrt_vartime*[C](r: var Fp[C], a: Fp[C]) =
## The square root, if it exist is multivalued,
## i.e. both x² == (-x)²
## This procedure returns a deterministic result
##
## This procedure is NOT constant-time
when C.has_P_3mod4_primeModulus():
r.invsqrt_p3mod4(a)
Expand All @@ -250,9 +251,13 @@ func sqrt*[C](a: var Fp[C]) =
a *= t

func sqrt_vartime*[C](a: var Fp[C]) =
## This is a vartime version of sqrt
## It is not constant-time
## This has the precomp optimisation
## Compute the square root of ``a``
##
## This requires ``a`` to be a square
##
## The result is undefined otherwise
##
## This is NOT constant-time
var t {.noInit.}: Fp[C]
t.invsqrt_vartime(a)
a *= t
Expand All @@ -271,8 +276,13 @@ func sqrt_invsqrt*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]) =
sqrt.prod(invsqrt, a)

func sqrt_invsqrt_vartime*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]) =
## It is not constant-time
## This has the precomp optimisation
## Compute the square root of ``a`` and inverse square root of ``a``
##
## This requires ``a`` to be a square
##
## The result is undefined otherwise
##
## This is NOT constant-time
invsqrt.invsqrt_vartime(a)
sqrt.prod(invsqrt, a)

Expand All @@ -292,8 +302,13 @@ func sqrt_invsqrt_if_square*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]): SecretBool
result = test == a

func sqrt_invsqrt_if_square_vartime*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]): SecretBool =
## It is not constant-time
## This has the precomp optimisation
## Compute the square root and ivnerse square root of ``a``
##
## This returns true if ``a`` is square and sqrt/invsqrt contains the square root/inverse square root
##
## The result is undefined otherwise
##
## This is NOT constant-time
sqrt_invsqrt_vartime(sqrt, invsqrt, a)
var test {.noInit.}: Fp[C]
test.square(sqrt)
Expand All @@ -311,6 +326,19 @@ func sqrt_if_square*[C](a: var Fp[C]): SecretBool =
result = sqrt_invsqrt_if_square(sqrt, invsqrt, a)
a = sqrt

func sqrt_if_square_vartime*[C](a: var Fp[C]): SecretBool =
## If ``a`` is a square, compute the square root of ``a``
## if not, ``a`` is undefined.
##
## The square root, if it exist is multivalued,
## i.e. both x² == (-x)²
## This procedure returns a deterministic result
##
## This is NOT constant-time
var sqrt{.noInit.}, invsqrt{.noInit.}: Fp[C]
result = sqrt_invsqrt_if_square_vartime(sqrt, invsqrt, a)
a = sqrt

func invsqrt_if_square*[C](r: var Fp[C], a: Fp[C]): SecretBool =
## If ``a`` is a square, compute the inverse square root of ``a``
## if not, ``a`` is undefined.
Expand All @@ -323,8 +351,10 @@ func invsqrt_if_square*[C](r: var Fp[C], a: Fp[C]): SecretBool =
result = sqrt_invsqrt_if_square(sqrt, r, a)

func invsqrt_if_square_vartime*[C](r: var Fp[C], a: Fp[C]): SecretBool =
## It is not constant-time
## This has the precomp optimisation
## If ``a`` is a square, compute the inverse square root of ``a``
## if not, ``a`` is undefined.
##
## This procedure is NOT constant-time
var sqrt{.noInit.}: Fp[C]
result = sqrt_invsqrt_if_square_vartime(sqrt, r, a)

Expand Down Expand Up @@ -365,8 +395,12 @@ func sqrt_ratio_if_square*(r: var Fp, u, v: Fp): SecretBool {.inline.} =
r *= u # √u/√v

func sqrt_ratio_if_square_vartime*(r: var Fp, u, v: Fp): SecretBool {.inline.} =
## It is not constant-time
## This has the precomp optimisation
## If u/v is a square, compute √(u/v)
## if not, the result is undefined
##
## r must not alias u or v
##
## This is NOT constant-time
var uv{.noInit.}: Fp
uv.prod(u, v) # uv
result = r.invsqrt_if_square_vartime(uv) # 1/√uv
Expand Down
140 changes: 23 additions & 117 deletions constantine/math/arithmetic/finite_fields_square_root_precomp.nim
Original file line number Diff line number Diff line change
Expand Up @@ -31,133 +31,33 @@ func sqrtAlg_NegDlogInSmallDyadicSubgroup_vartime(x: Fp): int {.tags:[VarTime],
let key = cast[int](x.mres.limbs[0] and SecretWord 0xFFFF)
return Fp.C.sqrtDlog(dlogLUT).getOrDefault(key, 0)


# sqrtAlg_GetPrecomputedRootOfUnity sets target to g^(multiplier << (order * sqrtParam_BlockSize)), where g is the fixed primitive 2^32th root of unity.
#
# We assume that order 0 <= order*sqrtParam_BlockSize <= 32 and that multiplier is in [0, 1 <<sqrtParam_BlockSize)
func sqrtAlg_GetPrecomputedRootOfUnity(target: var Fp, multiplier: int, order: uint) =
target = Fp.C.sqrtDlog(PrecomputedBlocks)[order][multiplier]


func sqrtAlg_ComputeRelevantPowers(z: Fp, squareRootCandidate: var Fp, rootOfUnity: var Fp) {.addchain.} =
## sliding window-type algorithm with window-size 5
## Note that we precompute and use z^255 multiple times (even though it's not size 5)
## and some windows actually overlap
var z2, z3, z7, z6, z9, z11, z13, z19, z21, z25, z27, z29, z31, z255 {.noInit.} : Fp
var acc: Fp
z2.square(z)
z3.prod(z2, z)
z6.prod(z3, z3)
z7.prod(z6, z)
z9.prod(z7, z2)
z11.prod(z9, z2)
z13.prod(z11, z2)
z19.prod(z13, z6)
z21.prod(z19, z2)
z25.prod(z19, z6)
z27.prod(z25, z2)
z29.prod(z27, z2)
z31.prod(z29, z2)
acc.prod(z27, z29)
acc.prod(acc, acc)
acc.prod(acc, acc)
z255.prod(acc, z31)
acc.prod(acc, acc)
acc.prod(acc, acc)
acc.prod(acc, z31)
acc.square_repeated(6)
acc.prod(acc, z27)
acc.square_repeated(6)
acc.prod(acc, z19)
acc.square_repeated(5)
acc.prod(acc, z21)
acc.square_repeated(7)
acc.prod(acc, z25)
acc.square_repeated(6)
acc.prod(acc, z19)
acc.square_repeated(5)
acc.prod(acc, z7)
acc.square_repeated(5)
acc.prod(acc, z11)
acc.square_repeated(5)
acc.prod(acc, z29)
acc.square_repeated(5)
acc.prod(acc, z9)
acc.square_repeated(7)
acc.prod(acc, z3)
acc.square_repeated(7)
acc.prod(acc, z25)
acc.square_repeated(5)
acc.prod(acc, z25)
acc.square_repeated(5)
acc.prod(acc, z27)
acc.square_repeated(8)
acc.prod(acc, z)
acc.square_repeated(8)
acc.prod(acc, z)
acc.square_repeated(6)
acc.prod(acc, z13)
acc.square_repeated(7)
acc.prod(acc, z7)
acc.square_repeated(3)
acc.prod(acc, z3)
acc.square_repeated(13)
acc.prod(acc, z21)
acc.square_repeated(5)
acc.prod(acc, z9)
acc.square_repeated(5)
acc.prod(acc, z27)
acc.square_repeated(5)
acc.prod(acc, z27)
acc.square_repeated(5)
acc.prod(acc, z9)
acc.square_repeated(10)
acc.prod(acc, z)
acc.square_repeated(7)
acc.prod(acc, z255)
acc.square_repeated(8)
acc.prod(acc, z255)
acc.square_repeated(6)
acc.prod(acc, z11)
acc.square_repeated(9)
acc.prod(acc, z255)
acc.square_repeated(2)
acc.prod(acc, z)
acc.square_repeated(7)
acc.prod(acc, z255)
acc.square_repeated(8)
acc.prod(acc, z255)
acc.square_repeated(8)
acc.prod(acc, z255)
acc.square_repeated(8)
acc.prod(acc, z255)
# acc is now z^((BaseFieldMultiplicativeOddOrder - 1)/2)
rootOfUnity.square(acc)
rootOfUnity *= z
squareRootCandidate.prod(acc, z)


func invSqrtEqDyadic_vartime*(z: var Fp) =
## The algorithm works by essentially computing the dlog of z and then halving it.
## negExponent is intended to hold the negative of the dlog of z.
func invSqrtEqDyadic_vartime*(a: var Fp) =
## The algorithm works by essentially computing the dlog of a and then halving it.
## negExponent is intended to hold the negative of the dlog of a.
## We determine this 32-bit value (usually) _sqrtBlockSize many bits at a time, starting with the least-significant bits.
##
## If _sqrtBlockSize does not divide 32, the *first* iteration will determine fewer bits.

var negExponent: int
var temp, temp2: Fp

# set powers[i] to z^(1<< (i*blocksize))
# set powers[i] to a^(1<< (i*blocksize))
var powers: array[4, Fp]
powers[0] = z
powers[0] = a
for i in 1 ..< Fp.C.sqrtDlog(Blocks):
powers[i] = powers[i - 1]
for j in 0 ..< Fp.C.sqrtDlog(BlockSize):
powers[i].square(powers[i])

## looking at the dlogs, powers[i] is essentially the wanted exponent, left-shifted by i*_sqrtBlockSize and taken mod 1<<32
## dlogHighDyadicRootNeg essentially (up to sign) reads off the _sqrtBlockSize many most significant bits. (returned as low-order bits)
##
##
## first iteration may be slightly special if BlockSize does not divide 32
negExponent = sqrtAlg_NegDlogInSmallDyadicSubgroup_vartime(powers[Fp.C.sqrtDlog(Blocks) - 1])
negExponent = negExponent shr Fp.C.sqrtDlog(FirstBlockUnusedBits)
Expand All @@ -173,21 +73,27 @@ func invSqrtEqDyadic_vartime*(z: var Fp) =
for j in 0 ..< i:
sqrtAlg_GetPrecomputedRootOfUnity(temp, int( (negExponent shr (j*Fp.C.sqrtDlog(BlockSize))) and Fp.C.sqrtDlog(BitMask) ), uint(j + Fp.C.sqrtDlog(Blocks) - 1 - i))
temp2.prod(temp2, temp)

var newBits = sqrtAlg_NegDlogInSmallDyadicSubgroup_vartime(temp2)
negExponent = negExponent or (newBits shl ((i*Fp.C.sqrtDlog(BlockSize)) - Fp.C.sqrtDlog(FirstBlockUnusedBits)))

negExponent = negExponent shr 1
z.setOne()
a.setOne()

for i in 0 ..< Fp.C.sqrtDlog(Blocks):
sqrtAlg_GetPrecomputedRootOfUnity(temp, int((negExponent shr (i*Fp.C.sqrtDlog(BlockSize))) and Fp.C.sqrtDlog(BitMask)), uint(i))
z.prod(z, temp)

func inv_sqrt_precomp_vartime*(dst: var Fp, x: Fp) {.inline.} =
dst.setOne()
var candidate, rootOfUnity {.noInit.}: Fp
sqrtAlg_ComputeRelevantPowers(x, candidate, rootOfUnity)
invSqrtEqDyadic_vartime(rootOfUnity)
dst.prod(candidate, rootOfUnity)
dst.inv()
a.prod(a, temp)

func inv_sqrt_precomp_vartime*(r: var Fp, a: Fp) =
var candidate, powLargestPowerOfTwo {.noInit.}: Fp
# Compute
# candidate = a^((q-1-2^e)/(2*2^e))
# with
# s and e, precomputed values
# such as q == s * 2^e + 1 the field modulus
# e is the 2-adicity of the field (the 2^e is the largest power of two that divides q-1)
candidate.precompute_tonelli_shanks_addchain(a)
powLargestPowerOfTwo.square(candidate)
powLargestPowerOfTwo *= a
invSqrtEqDyadic_vartime(powLargestPowerOfTwo)
r.prod(candidate, powLargestPowerOfTwo)
8 changes: 4 additions & 4 deletions constantine/math/constants/bandersnatch_sqrt.nim
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,10 @@ func precompute_tonelli_shanks_addchain*(
r *= a


# ############################################################
#
# Optimized square-root via Discrete Log lookup tables
#
# ############################################################
#
# Optimized square-root via Discrete Log lookup tables
#
# ############################################################

const
Expand Down

0 comments on commit 008fe10

Please sign in to comment.