Skip to content

Commit

Permalink
fix: suggested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
advaita-saha committed Feb 12, 2024
1 parent 973e0aa commit f76f222
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 13 deletions.
21 changes: 18 additions & 3 deletions constantine/math/arithmetic/finite_fields_square_root.nim
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,28 @@ func invsqrt*[C](r: var Fp[C], a: Fp[C]) =
r.invsqrt_p3mod4(a)
elif C.has_P_5mod8_primeModulus():
r.invsqrt_p5mod8(a)
elif C == Bandersnatch or C == Banderwagon:
r.inv_sqrt_precomp_vartime(a) # should be changed
else:
r.invsqrt_tonelli_shanks(a)

func invsqrt_vartime*[C](r: var Fp[C], a: Fp[C]) =
r.inv_sqrt_precomp_vartime(a)
## Compute the inverse square root of ``a``
##
## This requires ``a`` to be a square
##
## The result is undefined otherwise
##
## The square root, if it exist is multivalued,
## i.e. both x² == (-x)²
## This procedure returns a deterministic result
## This procedure is NOT constant-time
when C.has_P_3mod4_primeModulus():
r.invsqrt_p3mod4(a)
elif C.has_P_5mod8_primeModulus():
r.invsqrt_p5mod8(a)
elif C == Bandersnatch or C == Banderwagon:
r.inv_sqrt_precomp_vartime(a)
else:
r.invsqrt_tonelli_shanks(a)

func sqrt*[C](a: var Fp[C]) =
## Compute the square root of ``a``
Expand Down
24 changes: 16 additions & 8 deletions constantine/math/constants/bandersnatch_sqrt.nim
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ import
../io/[io_bigints, io_fields],
../arithmetic/finite_fields

const
# with e = 2adicity
# p == s * 2^e + 1
# root_of_unity = smallest_quadratic_nonresidue^s
# exponent = (p-1-2^e)/2^e / 2
Bandersnatch_TonelliShanks_exponent* = BigInt[222].fromHex"0x39f6d3a994cebea4199cec0404d0ec02a9ded2017fff2dff7fffffff"
Bandersnatch_TonelliShanks_twoAdicity* = 32
Bandersnatch_TonelliShanks_root_of_unity* = Fp[Bandersnatch].fromHex"0x212d79e5b416b6f0fd56dc8d168d6c0c4024ff270b3e0941b788f500b912f1f"

# ############################################################
#
# Specialized Tonelli-Shanks for Bandersnatch
Expand Down Expand Up @@ -148,15 +157,14 @@ func precompute_tonelli_shanks_addchain*(
r.square()
r *= a

const
# with e = 2adicity
# p == s * 2^e + 1
# root_of_unity = smallest_quadratic_nonresidue^s
# exponent = (p-1-2^e)/2^e / 2
Bandersnatch_TonelliShanks_exponent* = BigInt[222].fromHex"0x39f6d3a994cebea4199cec0404d0ec02a9ded2017fff2dff7fffffff"
Bandersnatch_TonelliShanks_twoAdicity* = 32
Bandersnatch_TonelliShanks_root_of_unity* = Fp[Bandersnatch].fromHex"0x212d79e5b416b6f0fd56dc8d168d6c0c4024ff270b3e0941b788f500b912f1f"

# ############################################################
#
# Optimized square-root via Discrete Log lookup tables
#
# ############################################################

const
Bandersnatch_SqrtDlog_TotalBits* = Bandersnatch_TonelliShanks_twoAdicity
Bandersnatch_SqrtDlog_BlockSize* = 8
Bandersnatch_SqrtDlog_Blocks* = 4 #Bandersnatch_TonelliShanks_sqrtParam_TotalBits / Bandersnatch_TonelliShanks_sqrtParam_BlockSize
Expand Down
3 changes: 2 additions & 1 deletion constantine/math/constants/banderwagon_sqrt.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1496,4 +1496,5 @@ func precompute_tonelli_shanks_addchain*(
# 261 + 3 = 264 operations
r *= x11111111
r.square()
r *= a
r *= a

55 changes: 54 additions & 1 deletion constantine/serialization/codecs_banderwagon.nim
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,44 @@ func deserialize_unchecked*(dst: var EC_Prj, src: array[32, byte]): CttCodecEccS
var x{.noInit.}: Fp[Banderwagon]
x.fromBig(t)

let onCurve = dst.trySetFromCoordX_vartime(x) # later to be shifted to a constant time version
let onCurve = dst.trySetFromCoordX(x)
if not(bool onCurve):
return cttCodecEcc_PointNotOnCurve

let isLexicographicallyLargest = dst.y.toBig() >= Fp[Banderwagon].getPrimeMinus1div2()
dst.y.cneg(not isLexicographicallyLargest)

return cttCodecEcc_Success

func deserialize_unchecked_vartime*(dst: var EC_Prj, src: array[32, byte]): CttCodecEccStatus =
## This is not in constant-time
## Deserialize a Banderwagon point (x, y) in format
##
## if y is not lexicographically largest
## set y -> -y
##
## Returns cttCodecEcc_Success if successful
## https://hackmd.io/@6iQDuIePQjyYBqDChYw_jg/BJBNcv9fq#Serialisation
# If infinity, src must be all zeros
var check: bool = true
for i in 0 ..< src.len:
if src[i] != byte 0:
check = false
break
if check:
dst.setInf()
return cttCodecEcc_PointAtInfinity

var t{.noInit.}: matchingBigInt(Banderwagon)
t.unmarshal(src, bigEndian)

if bool(t >= Banderwagon.Mod()):
return cttCodecEcc_CoordinateGreaterThanOrEqualModulus

var x{.noInit.}: Fp[Banderwagon]
x.fromBig(t)

let onCurve = dst.trySetFromCoordX_vartime(x)
if not(bool onCurve):
return cttCodecEcc_PointNotOnCurve

Expand All @@ -148,6 +185,22 @@ func deserialize*(dst: var EC_Prj, src: array[32, byte]): CttCodecEccStatus =

return cttCodecEcc_Success

func deserialize_vartime*(dst: var EC_Prj, src: array[32, byte]): CttCodecEccStatus =
## Deserialize a Banderwagon point (x, y) in format
##
## Also checks if the point lies in the banderwagon scheme subgroup
##
## Returns cttCodecEcc_Success if successful
## Returns cttCodecEcc_PointNotInSubgroup if doesn't lie in subgroup
result = deserialize_unchecked_vartime(dst, src)
if result != cttCodecEcc_Success:
return result

if not(bool dst.isInSubgroup()):
return cttCodecEcc_PointNotInSubgroup

return cttCodecEcc_Success

## ############################################################
##
## Banderwagon Scalar Serialization
Expand Down
21 changes: 21 additions & 0 deletions tests/t_ethereum_verkle_primitives.nim
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,27 @@ suite "Banderwagon Serialization Tests":
doAssert (point == points[i]).bool(), "Decoded Element is different from expected element"

testDeserialization(expected_bit_strings.len)

## Check decoding if it is as expected or not ( vartime impl )
test "vartime - Decoding Each bit string":
proc testDeserialization_vartime(len: int) =
# Checks if the point serialized in the previous
# tests matches with the deserialization of expected strings
for i, bit_string in expected_bit_strings:

# converts serialized value in hex to byte array
var arr: Bytes
discard arr.parseHex(bit_string)

# deserialization from expected bits
var point{.noInit.}: EC
let stat = point.deserialize_vartime(arr)

# Assertion check for the Deserialization Success & correctness
doAssert stat == cttCodecEcc_Success, "Deserialization Failed"
doAssert (point == points[i]).bool(), "Decoded Element is different from expected element"

testDeserialization_vartime(expected_bit_strings.len)

# Check if the subgroup check is working on eliminating
# points which don't lie on banderwagon, while
Expand Down

0 comments on commit f76f222

Please sign in to comment.