Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dettman Multiplication Arithmetic #1500

Merged
merged 13 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions src/Arithmetic/Core.v
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Require Import Crypto.Arithmetic.ModularArithmeticTheorems.
Require Import Crypto.Util.Decidable.
Require Import Crypto.Util.LetIn.
Require Import Crypto.Util.ListUtil.
Import Crypto.Util.ListUtil.Reifiable.
Require Import Crypto.Util.NatUtil.
Require Import Crypto.Util.Prod.
Require Import Crypto.Util.Decidable.Bool2Prop.
Expand Down Expand Up @@ -197,6 +198,46 @@ Module Associational.
now rewrite eval_reduce_adjusted.
Qed.

Definition split_one (s:Z) (w fw : Z) (p:list (Z*Z)) :=
let hi_lo := partition (fun t => (fst t =? w)) p in
(snd hi_lo, map (fun t => (fst t / fw, snd t)) (fst hi_lo)).

Lemma eval_split_one s w fw p (s_nz:s<>0) (fw_nz:fw<>0) (w_fw : w mod fw = 0) (fw_s : fw mod s = 0):
Associational.eval (fst (split_one s w fw p)) + fw * Associational.eval (snd (split_one s w fw p)) = Associational.eval p.
Proof.
remember (Z_div_exact_full_2 _ _ fw_nz w_fw) as H2.
clear HeqH2 fw_nz w_fw.
induction p as [|t p' IHp'].
- simpl. cbv [Associational.eval]. simpl. lia.
- cbv [split_one]. simpl. destruct (fst t =? w) eqn:E.
+ simpl in IHp'. remember (partition (fun t0 : Z * Z => fst t0 =? w) p') as thing.
destruct thing as [thing1 thing2]. simpl. simpl in IHp'. repeat rewrite Associational.eval_cons.
ring_simplify. simpl.
apply Z.eqb_eq in E. rewrite E. rewrite <- H2. rewrite <- IHp'. ring.
+ simpl in IHp'. remember (partition (fun t0 : Z * Z => fst t0 =? w) p') as thing.
destruct thing as [thing1 thing2]. simpl. simpl in IHp'. repeat rewrite Associational.eval_cons.
rewrite <- IHp'. ring.
Qed.

Definition reduce_one (s:Z) (w fw : Z) (c: Z) (p:list _) : list (Z*Z) :=
let lo_hi := split_one s w fw p in
fst lo_hi ++ map (fun thing => (fst thing, snd thing * (c * (fw / s)))) (snd lo_hi).

Lemma eval_map_mul_snd (x:Z) (p:list (Z*Z))
: Associational.eval (List.map (fun t => (fst t, snd t * x)) p) = x * Associational.eval p.
Proof. induction p; push; nsatz. Qed.

Lemma eval_reduce_one s w fw c p (s_nz:s<>0) (fw_nz:fw<>0) (w_fw : w mod fw = 0) (fw_s : fw mod s = 0)
(modulus_nz: s - c<>0) :
Associational.eval (reduce_one s w fw c p) mod (s - c) =
Associational.eval p mod (s - c).
Proof using Type.
cbv [reduce_one]; push.
rewrite eval_map_mul_snd. rewrite <- Z.mul_assoc.
rewrite <- (reduction_rule _ _ _ _ modulus_nz).
rewrite Z.mul_assoc. rewrite <- (Z_div_exact_full_2 fw s s_nz fw_s). rewrite eval_split_one; trivial.
Qed.

(*
Definition splitQ (s:Q) (p:list (Z*Z)) : list (Z*Z) * list (Z*Z)
:= let hi_lo := partition (fun t => (fst t * Zpos (Qden s)) mod (Qnum s) =? 0) p in
Expand Down Expand Up @@ -486,6 +527,58 @@ Module Associational.
push; [|rewrite IHp]; reflexivity.
Qed.

Definition value_at_weight (a : list (Z * Z)) (d : Z) :=
fold_right Z.add 0 (map snd (filter (fun p => fst p =? d) a)).

Lemma value_at_weight_works a d : d * (value_at_weight a d) = Associational.eval (filter (fun p => fst p =? d) a).
Proof.
induction a as [| a0 a' IHa'].
- cbv [Associational.eval value_at_weight]. simpl. lia.
- cbv [value_at_weight]. simpl. destruct (fst a0 =? d) eqn:E.
+ rewrite Associational.eval_cons. simpl. rewrite <- IHa'. cbv [value_at_weight]. lia.
+ apply IHa'.
Qed.

Lemma not_in_value_0 a d : ~ In d (map fst a) -> value_at_weight a d = 0.
Proof.
intros H. induction a as [| x a' IHa'].
- reflexivity.
- cbv [value_at_weight]. simpl. destruct (fst x =? d) eqn:E.
+ exfalso. apply H. simpl. lia.
+ apply IHa'. intros H'. apply H. simpl. right. apply H'.
Qed.

Definition dedup_weights a :=
map (fun d => (d, value_at_weight a d)) (nodupb Z.eqb (map fst a)).

Lemma funs_same (l : list Z) (a0 : Z*Z) (a' : list (Z*Z)) :
~ In (fst a0) l ->
forall d, In d l ->
(fun d : Z => (d, value_at_weight (a0 :: a') d)) d = (fun d => (d, value_at_weight a' d)) d.
Proof.
intros H d H'. simpl. f_equal. cbv [value_at_weight]. simpl. destruct (fst a0 =? d) eqn:E.
- exfalso. rewrite Z.eqb_eq in E. subst. apply (H H').
- reflexivity.
Qed.

Lemma eval_dedup_weights a : Associational.eval (dedup_weights a) = Associational.eval a.
Proof.
induction a as [| a0 a' IHa'].
- reflexivity.
- cbv [dedup_weights]. simpl. destruct (existsb (Z.eqb (fst a0)) (nodupb Z.eqb (map fst a'))) eqn:E.
+ apply (existsb_eqb_true_iff Z.eqb Z.eqb_eq) in E. rewrite <- (nodupb_in_iff Z.eqb Z.eqb_eq) in E.
apply (nodupb_split Z.eqb Z.eqb_eq) in E. destruct E as [l1 [l2 [H1 [H2 H3] ] ] ]. rewrite H1.
repeat rewrite map_app. rewrite (map_ext_in _ _ l1 (funs_same l1 a0 a' H2)).
rewrite (map_ext_in _ _ l2 (funs_same l2 a0 a' H3)). repeat rewrite Associational.eval_app. simpl.
repeat rewrite Associational.eval_cons. simpl. rewrite <- IHa'. simpl. rewrite Associational.eval_nil.
cbv [dedup_weights]. rewrite H1. repeat rewrite map_app. repeat rewrite Associational.eval_app.
cbv [value_at_weight]. simpl. rewrite Z.eqb_refl. simpl. cbv [Associational.eval]. simpl. lia.
+ simpl. apply (existsb_eqb_false_iff Z.eqb Z.eqb_eq) in E. rewrite (map_ext_in _ _ _ (funs_same _ _ _ E)).
repeat rewrite Associational.eval_cons. simpl. rewrite <- IHa'. cbv [dedup_weights]. f_equal. f_equal.
rewrite <- (nodupb_in_iff Z.eqb Z.eqb_eq) in E. cbv [value_at_weight]. simpl. rewrite Z.eqb_refl.
apply not_in_value_0 in E. cbv [value_at_weight] in E. simpl. rewrite E. lia.
Qed.

Section Carries.
Definition carryterm (w fw:Z) (t:Z * Z) :=
if (Z.eqb (fst t) w)
Expand All @@ -511,6 +604,35 @@ Module Associational.
eval (carry w fw p) = eval p.
Proof using Type*. cbv [carry]; induction p; push; nsatz. Qed.
Hint Rewrite eval_carry using auto : push_eval.

Definition borrowterm (w fw:Z) (t:Z * Z) :=
let quot := w / fw in
if (Z.eqb (fst t) w)
then [(quot, snd t * fw)]
else [t].

Lemma eval_borrowterm w fw (t:Z * Z) (fw_nz:fw<>0) (w_fw:w mod fw = 0) :
Associational.eval (borrowterm w fw t) = Associational.eval [t].
Proof using Type*.
cbv [borrowterm Let_In]; break_match; push; [|trivial].
pose proof (Z.div_mod (snd t) fw fw_nz).
rewrite Z.eqb_eq in *.
ring_simplify. rewrite Z.mul_comm. rewrite Z.mul_assoc. rewrite <- Z_div_exact_full_2; lia.
Qed.

Definition borrow (w fw:Z) (p:list (Z*Z)) :=
flat_map (borrowterm w fw) p.

Lemma eval_borrow w fw p (fw_nz:fw<>0) (w_fw:w mod fw = 0):
Associational.eval (borrow w fw p) = Associational.eval p.
Proof using Type*.
cbv [borrow borrowterm]. induction p as [| a p' IHp'].
- trivial.
- push. destruct (fst a =? w) eqn:E.
+ rewrite Z.mul_comm. rewrite <- Z.mul_assoc. rewrite <- Z_div_exact_full_2; lia.
+ rewrite IHp'. lia.
Qed.

End Carries.
End Associational.

Expand Down
157 changes: 157 additions & 0 deletions src/Arithmetic/DettmanMultiplication.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
Require Import Crypto.Arithmetic.Core.
Require Import Coq.ZArith.ZArith Coq.micromega.Lia.
Require Import Coq.Lists.List.
Require Import Crypto.Util.ZUtil.Modulo.PullPush.
Local Open Scope list_scope.

Import Associational Positional.
Import ListNotations. Local Open Scope Z_scope.

Local Coercion Z.of_nat : nat >-> Z.

Section __.

Context
(e : nat)
(c_ : list (Z*Z))
(p_nz : 2 ^ e - Associational.eval c_ <> 0)
(limbs : nat)
(limb_size : nat)
(limbs_gteq_3 : (3 <= limbs)%nat)
(e_small : (e <= limb_size * limbs)%nat)
(e_big : (limb_size * (limbs - 1) <= e)%nat).

Let s := (2 ^ e).

Let c := Associational.eval c_.

Let base := (2 ^ limb_size).

Lemma base_nz : base <> 0.
Proof. cbv [base]. apply Z.pow_nonzero; lia. Qed.

Lemma s_nz : s <> 0.
Proof. cbv [s]. apply Z.pow_nonzero; lia. Qed.

Let weight (n : nat) := base ^ n.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


Lemma weight_0 : weight 0 = 1.
Proof. reflexivity. Qed.

Lemma weight_nz : forall i, weight i <> 0.
Proof. intros i. cbv [weight]. apply Z.pow_nonzero; lia. Qed.

Lemma mod_is_zero : forall b (n m : nat), b <> 0 -> le n m -> (b ^ m) mod (b ^ n) = 0.
intros b n m H1 H2. induction H2 as [|m' nlem' IHm'].
- rewrite Z_mod_same_full. constructor.
- rewrite Nat2Z.inj_succ. cbv [Z.succ]. rewrite <- Pow.Z.pow_mul_base by lia.
rewrite Z.mul_mod_full. rewrite IHm'. rewrite Z.mul_0_r. apply Z.mod_0_l. apply Z.pow_nonzero; lia.
Qed.

Lemma div_nz a b : b > 0 -> b <= a -> a / b <> 0.
Proof.
intros H1 H2. assert (H: 1 <= a / b).
- replace 1 with (b / b).
+ apply Z_div_le; assumption.
+ apply Z_div_same. apply H1.
- symmetry. apply Z.lt_neq. apply Z.lt_le_trans with 1.
+ reflexivity.
+ apply H.
Qed.

Lemma limbs_mod_s_0 : (weight limbs) mod s = 0.
Proof.
cbv [weight base s]. rewrite <- Z.pow_mul_r by lia. rewrite <- Nat2Z.inj_mul. apply mod_is_zero; lia.
Qed.

Local Open Scope nat_scope.

Definition reduce' x1 x2 x3 x4 x5 := dedup_weights (reduce_one x1 x2 x3 x4 x5).
Definition carry' x1 x2 x3 := dedup_weights (Associational.carry x1 x2 x3).

Definition loop_body i before :=
let middle1 := carry' (weight (i + limbs)) (weight 1) before in
let middle2 := reduce' s (weight (i + limbs)) (weight limbs) c middle1 in
let after := carry' (weight i) (weight 1) middle2 in
after.

Hint Rewrite eval_reduce_one Associational.eval_carry eval_dedup_weights: push_eval.

Lemma eval_loop_body i before :
(Associational.eval (loop_body i before) mod (s - c) =
Associational.eval before mod (s - c))%Z.
Proof.
cbv [loop_body carry' reduce']. autorewrite with push_eval. reflexivity.
- apply weight_nz.
- apply s_nz.
- apply weight_nz.
- cbv [weight]. apply mod_is_zero.
+ apply base_nz.
+ lia.
- cbv [weight]. apply limbs_mod_s_0.
- apply p_nz.
- apply weight_nz.
Qed.

Definition loop start :=
fold_right loop_body start (rev (seq 1 (limbs - 2 - 1))).

Lemma eval_loop start :
((Associational.eval (loop start)) mod (s - c) = (Associational.eval start) mod (s - c))%Z.
Proof.
cbv [loop]. induction (rev (seq 1 (limbs - 2 - 1))) as [| i l' IHl'].
- reflexivity.
- simpl. rewrite eval_loop_body. apply IHl'.
Qed.

Definition mulmod a b :=
let l := limbs in
let a_assoc := Positional.to_associational weight limbs a in
let b_assoc := Positional.to_associational weight limbs b in
let r0 := Associational.mul a_assoc b_assoc in
let r0' := dedup_weights r0 in
let r1 := carry' (weight (2 * l - 2)) (weight 1) r0' in
let r2 := reduce' s (weight (2 * l - 2)) (weight l) c r1 in
let r3 := carry' (weight (l - 2)) (weight 1) r2 in
let r4 := reduce' s (weight (2 * l - 1)) (weight l) c r3 in
let r5 := carry' (weight (l - 1)) (weight 1) r4 in
let r6 := carry' (weight (l - 1)) (Z.div s (weight (l - 1))) r5 in
let r7 := carry' (weight l) (weight 1) r6 in
let r8 := borrow (weight l) (weight l / s) r7 in
let r8' := dedup_weights r8 in
let r9 := reduce' s s s c r8' in
let r10 := carry' (weight 0) (weight 1) r9 in
let r11 := loop r10 in
let r12 := reduce' s (weight (2 * l - 2)) (weight l) c r11 in
let r13 := carry' (weight (l - 2)) (weight 1) r12 in
Positional.from_associational weight l r13.

Hint Rewrite Positional.eval_from_associational Positional.eval_to_associational eval_borrow eval_loop: push_eval.

Local Open Scope Z_scope.

Theorem eval_mulmod a b :
(Positional.eval weight limbs a * Positional.eval weight limbs b) mod (s - c) =
(Positional.eval weight limbs (mulmod a b)) mod (s - c).
Proof.
cbv [mulmod carry' reduce']. autorewrite with push_eval. reflexivity.
all:
cbv [weight s base];
try apply weight_nz;
try apply s_nz;
try apply p_nz;
try apply weight_0;
try apply Z_mod_same_full;
try apply limbs_mod_s_0;
try apply mod_is_zero;
try (remember limbs_gteq_3 as H; lia);
try apply base_nz.
- apply div_nz; try lia. rewrite <- Z.pow_mul_r by lia. rewrite <- Nat2Z.inj_mul.
rewrite <- Z.pow_le_mono_r_iff by lia. lia.
- apply div_nz; try lia. rewrite <- Z.pow_mul_r by lia. rewrite <- Nat2Z.inj_mul.
rewrite <- Z.pow_le_mono_r_iff by lia. lia.
- repeat rewrite <- Z.pow_mul_r by lia. rewrite <- Z.pow_sub_r by lia.
rewrite <- Nat2Z.inj_mul. rewrite <- Nat2Z.inj_sub by lia. apply mod_is_zero; lia.
Qed.

End __.
Loading