diff --git a/_CoqProject b/_CoqProject index 69fdc4dace..ad5fd28b5e 100644 --- a/_CoqProject +++ b/_CoqProject @@ -7,6 +7,7 @@ src/BaseSystem.v src/BaseSystemProofs.v src/EdDSARepChange.v src/MxDHRepChange.v +src/NewBaseSystem.v src/Testbit.v src/Algebra/ZToRing.v src/Assembly/Bounds.v @@ -438,6 +439,7 @@ src/Util/AdditionChainExponentiation.v src/Util/AutoRewrite.v src/Util/Bool.v src/Util/CaseUtil.v +src/Util/CPSUtil.v src/Util/Curry.v src/Util/Decidable.v src/Util/Equality.v diff --git a/src/Karatsuba.v b/src/Karatsuba.v new file mode 100644 index 0000000000..47ae2facf5 --- /dev/null +++ b/src/Karatsuba.v @@ -0,0 +1,49 @@ +Require Import Coq.ZArith.ZArith. +Require Import Crypto.Tactics.Algebra_syntax.Nsatz. +Require Import Crypto.Util.ZUtil. +Local Open Scope Z_scope. + +Section Karatsuba. + Context {T : Type} (eval : T -> Z) + (sub : T -> T -> T) + (eval_sub : forall x y, eval (sub x y) = eval x - eval y) + (mul : T -> T -> T) + (eval_mul : forall x y, eval (mul x y) = eval x * eval y) + (add : T -> T -> T) + (eval_add : forall x y, eval (add x y) = eval x + eval y) + (scmul : Z -> T -> T) + (eval_scmul : forall c x, eval (scmul c x) = c * eval x) + (split : Z -> T -> T * T) + (eval_split : forall s x, s <> 0 -> eval (fst (split s x)) + s * (eval (snd (split s x))) = eval x) + . + + Definition karatsuba_mul s (x y : T) : T := + let xab := split s x in + let yab := split s y in + let xy0 := mul (fst xab) (fst yab) in + let xy2 := mul (snd xab) (snd yab) in + let xy1 := sub (mul (add (fst xab) (snd xab)) (add (fst yab) (snd yab))) (add xy2 xy0) in + add (add (scmul (s^2) xy2) (scmul s xy1)) xy0. + + Lemma eval_karatsuba_mul s x y (s_nonzero:s <> 0) : + eval (karatsuba_mul s x y) = eval x * eval y. + Proof. cbv [karatsuba_mul]; repeat rewrite ?eval_sub, ?eval_mul, ?eval_add, ?eval_scmul. + rewrite <-(eval_split s x), <-(eval_split s y) by assumption; ring. Qed. + + + Definition goldilocks_mul s (xs ys : T) : T := + let a_b := split s xs in + let c_d := split s ys in + let ac := mul (fst a_b) (fst c_d) in + (add (add ac (mul (snd a_b) (snd c_d))) + (scmul s (sub (mul (add (fst a_b) (snd a_b)) (add (fst c_d) (snd c_d))) ac))). + + Local Existing Instances Z.equiv_modulo_Reflexive RelationClasses.eq_Reflexive Z.equiv_modulo_Symmetric Z.equiv_modulo_Transitive Z.mul_mod_Proper Z.add_mod_Proper Z.modulo_equiv_modulo_Proper. + + Lemma goldilocks_mul_correct (p : Z) (p_nonzero : p <> 0) s (s_nonzero : s <> 0) (s2_modp : (s^2) mod p = (s+1) mod p) xs ys : + (eval (goldilocks_mul s xs ys)) mod p = (eval xs * eval ys) mod p. + Proof. cbv [goldilocks_mul]; Zmod_to_equiv_modulo. + repeat rewrite ?eval_mul, ?eval_add, ?eval_sub, ?eval_scmul, <-?(eval_split s xs), <-?(eval_split s ys) by assumption; ring_simplify. + setoid_rewrite s2_modp. + apply f_equal2; nsatz. Qed. +End Karatsuba. diff --git a/src/NewBaseSystem.v b/src/NewBaseSystem.v new file mode 100644 index 0000000000..549ec84a0f --- /dev/null +++ b/src/NewBaseSystem.v @@ -0,0 +1,458 @@ +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz Coq.omega.Omega. +Require Import Coq.ZArith.BinIntDef. +Local Open Scope Z_scope. + +Require Import Crypto.Tactics.Algebra_syntax.Nsatz. +Require Import Crypto.Util.Tactics Crypto.Util.Decidable Crypto.Util.LetIn. +Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil Crypto.Util.Sigma. +Require Import Crypto.Util.CPSUtil Crypto.Util.Prod. + +Require Import Coq.Lists.List. Import ListNotations. +Require Crypto.Util.Tuple. Local Notation tuple := Tuple.tuple. + +Local Ltac prove_id := + repeat match goal with + | _ => progress intros + | _ => progress simpl + | _ => progress cbv [Let_In] + | _ => progress (autorewrite with uncps push_id in * ) + | _ => break_if + | _ => break_match + | _ => contradiction + | _ => reflexivity + | _ => nsatz + | _ => solve [auto] + end. + +Create HintDb push_basesystem_eval discriminated. +Local Ltac prove_eval := + repeat match goal with + | _ => progress intros + | _ => progress simpl + | _ => progress cbv [Let_In] + | _ => progress (autorewrite with push_basesystem_eval uncps push_id cancel_pair in * ) + | _ => break_if + | _ => break_match + | _ => split + | H : _ /\ _ |- _ => destruct H + | H : Some _ = Some _ |- _ => progress (inversion H; subst) + | _ => discriminate + | _ => reflexivity + | _ => nsatz + end. + +Delimit Scope runtime_scope with RT. +Definition runtime_mul := Z.mul. +Global Notation "a * b" := (runtime_mul a%RT b%RT) : runtime_scope. +Definition runtime_add := Z.add. +Global Notation "a + b" := (runtime_add a%RT b%RT) : runtime_scope. +Definition runtime_fst {A B} := @fst A B. +Definition runtime_snd {A B} := @snd A B. + +Module B. + Local Definition limb := (Z*Z)%type. (* position coefficient and run-time value *) + Module Associational. + Definition eval (p:list limb) : Z := + List.fold_right Z.add 0%Z (List.map (fun t => fst t * snd t) p). + + Lemma eval_nil : eval nil = 0. Proof. reflexivity. Qed. + Lemma eval_cons p q : eval (p::q) = (fst p) * (snd p) + eval q. Proof. reflexivity. Qed. + Lemma eval_app p q: eval (p++q) = eval p + eval q. + Proof. induction p; simpl eval; rewrite ?eval_nil, ?eval_cons; nsatz. Qed. + Hint Rewrite eval_nil eval_cons eval_app : push_basesystem_eval. + + Definition multerm (t t' : limb) : limb := + (fst t * fst t', (snd t * snd t')%RT). + Definition mul_cps (p q:list limb) {T} (f : list limb->T) := + flat_map_cps (fun t => @map_cps _ _ (multerm t) q) p f. + Definition mul (p q:list limb) := mul_cps p q id. + Hint Opaque mul : uncps. + Lemma eval_map_mul (a:limb) (q:list limb) : eval (List.map (multerm a) q) = fst a * snd a * eval q. + Proof. + induction q; cbv [multerm]; simpl List.map; + autorewrite with push_basesystem_eval cancel_pair; nsatz. + Qed. Hint Rewrite eval_map_mul : push_basesystem_eval. + Lemma mul_cps_id p q: forall {T} f, + @mul_cps p q T f = f (mul p q). + Proof. cbv [mul_cps mul]; prove_id. Qed. Hint Rewrite mul_cps_id : uncps. + Lemma eval_mul_noncps p q: + eval (mul p q) = eval p * eval q. + Proof. + cbv [mul mul_cps]; induction p; prove_eval. Qed. Hint Rewrite eval_mul_noncps : push_basesystem_eval. + + Fixpoint split (s:Z) (xs:list limb) + {T} (f :list limb*list limb->T) := + match xs with + | nil => f (nil, nil) + | cons x xs' => + split s xs' + (fun sxs' => + if dec (fst x mod s = 0) + then f (fst sxs', cons (fst x / s, snd x) (snd sxs')) + else f (cons x (fst sxs'), snd sxs')) + end. + Definition split_noncps s xs := split s xs id. + Hint Opaque split_noncps : uncps. + Lemma split_id s p: forall {T} f, + @split s p T f = f (split_noncps s p). + Proof. + induction p; + repeat match goal with + | _ => rewrite IHp + | _ => progress (cbv [split_noncps]; prove_id) + end. + Qed. Hint Rewrite split_id : uncps. + Lemma eval_split_noncps s p (s_nonzero:s<>0): + eval (fst (split_noncps s p)) + s*eval (snd (split_noncps s p)) = eval p. + Proof. + cbv [split_noncps]; induction p; prove_eval. + match goal with H:_ |- _ => + unique pose proof (Z_div_exact_full_2 _ _ s_nonzero H) + end; nsatz. + Qed. Hint Rewrite @eval_split_noncps using auto : push_basesystem_eval. + + Definition reduce_cps (s:Z) (c:list limb) (p:list limb) + {T} (f : list limb->T) := + split s p (fun ab =>mul_cps c (snd ab) (fun rr =>f (fst ab ++ rr))). + Definition reduce s c p := reduce_cps s c p id. + Hint Opaque reduce : uncps. + Lemma reduction_rule a b s c (modulus_nonzero:s-c<>0) : + (a + s * b) mod (s - c) = (a + c * b) mod (s - c). + Proof. replace (a + s * b) with ((a + c*b) + b*(s-c)) by nsatz. + rewrite Z.add_mod, Z_mod_mult, Z.add_0_r, Z.mod_mod; trivial. Qed. + Lemma reduce_cps_id s c p {T} f: + @reduce_cps s c p T f = f (reduce s c p). + Proof. cbv [reduce_cps reduce]; prove_id. Qed. Hint Rewrite reduce_cps_id : uncps. + Lemma eval_reduce s c p (s_nonzero:s<>0) (modulus_nonzero:s-eval c<>0): + eval (reduce s c p) mod (s - eval c) = eval p mod (s - eval c). + Proof. + cbv [reduce reduce_cps]; prove_eval; + rewrite <-reduction_rule by auto; prove_eval. + Qed. Hint Rewrite eval_reduce : push_basesystem_eval. + + Section Carries. + Context {modulo div:Z->Z->Z}. + Context {div_mod : forall a b:Z, b <> 0 -> + a = b * (div a b) + modulo a b}. + + Definition carryterm_cps (w fw:Z) (t:limb) {T} (f:list limb->T) := + if dec (fst t = w) + then dlet d := div (snd t) fw in + dlet m := modulo (snd t) fw in + f ((w*fw, d) :: (w, m) :: @nil limb) + else f [t]. + Definition carry_cps(w fw:Z) (p:list limb) {T} (f:list limb->T) := + flat_map_cps (carryterm_cps w fw) p f. + Definition carryterm w fw t := carryterm_cps w fw t id. + Hint Opaque carryterm : uncps. + Definition carry w fw p := carry_cps w fw p id. + Hint Opaque carry : uncps. + Lemma carryterm_cps_id w fw t {T} f : + @carryterm_cps w fw t T f + = f (@carryterm w fw t). + Proof. cbv [carryterm_cps carryterm Let_In]; prove_id. Qed. Hint Rewrite carryterm_cps_id : uncps. + Lemma eval_carryterm w fw (t:limb) (fw_nonzero:fw<>0): + eval (carryterm w fw t) = eval [t]. + Proof. + cbv [carryterm_cps carryterm Let_In]; prove_eval. + specialize (div_mod (snd t) fw fw_nonzero). + nsatz. + Qed. Hint Rewrite eval_carryterm using auto : push_basesystem_eval. + Lemma carry_cps_id w fw p {T} f: + @carry_cps w fw p T f = f (carry w fw p). + Proof. cbv [carry_cps carry]; prove_id. Qed. + Hint Rewrite carry_cps_id : uncps. + Lemma eval_carry w fw p (fw_nonzero:fw<>0): + eval (carry w fw p) = eval p. + Proof. cbv [carry_cps carry]; induction p; prove_eval. Qed. + Hint Rewrite eval_carry using auto : push_basesystem_eval. + End Carries. + + Section Saturated. + Context {word_max : Z} {word_max_pos : 1 < word_max} + {add : Z -> Z -> Z * Z} + {add_correct : forall x y, fst (add x y) + word_max * snd (add x y) = x + y} + {mul : Z -> Z -> Z * Z} + {mul_correct : forall x y, fst (mul x y) + word_max * snd (mul x y) = x * y} + {end_wt:Z} {end_wt_pos : 0 < end_wt} + . + + Definition sat_multerm_cps (t t' : limb) {T} (f:list limb->T) := + dlet tt' := mul (snd t) (snd t') in + f ((fst t*fst t', runtime_fst tt') :: (fst t*fst t'*word_max, runtime_snd tt') :: nil)%list. + Definition sat_mul_cps (p q : list limb) {T} (f:list limb->T) := + flat_map_cps (fun t => @flat_map_cps _ _ (sat_multerm_cps t) q) p f. + (* TODO (jgross): kind of an interesting behavior--it infers the type arguments like this but fails to check if I leave them implicit *) + Definition sat_multerm t t' := sat_multerm_cps t t' id. + Definition sat_mul p q := sat_mul_cps p q id. + Hint Opaque sat_multerm sat_mul : uncps. + Lemma sat_multerm_cps_id t t' : forall {T} (f:list limb->T), + sat_multerm_cps t t' f = f (sat_multerm t t'). + Proof. reflexivity. Qed. Hint Rewrite sat_multerm_cps_id : uncps. + Lemma eval_map_sat_multerm_cps t q : + eval (flat_map (fun x => sat_multerm_cps t x id) q) = fst t * snd t * eval q. + Proof. + cbv [sat_multerm sat_multerm_cps Let_In runtime_fst runtime_snd]; + induction q; prove_eval; + try match goal with |- context [mul ?a ?b] => + specialize (mul_correct a b) end; + nsatz. + Qed. Hint Rewrite eval_map_sat_multerm_cps : push_basesystem_eval. + Lemma sat_mul_cps_id p q {T} f : @sat_mul_cps p q T f = f (sat_mul p q). + Proof. cbv [sat_mul_cps sat_mul]; prove_id. Qed. Hint Rewrite sat_mul_cps_id : uncps. + Lemma eval_sat_mul p q : eval (sat_mul p q) = eval p * eval q. + Proof. cbv [sat_mul_cps sat_mul]; induction p; prove_eval. Qed. + Hint Rewrite eval_sat_mul : push_basesystem_eval. + + End Saturated. + End Associational. + Hint Rewrite + @Associational.sat_mul_cps_id + @Associational.sat_multerm_cps_id + @Associational.carry_cps_id + @Associational.carryterm_cps_id + @Associational.reduce_cps_id + @Associational.split_id + @Associational.mul_cps_id : uncps. + + Module Positional. + Section Positional. + Import Associational. + Context (weight : nat -> Z) (* [weight i] is the weight of position [i] *) + (weight_0 : weight 0%nat = 1%Z) + (weight_nonzero : forall i, weight i <> 0). + + (** Converting from positional to associational *) + + Definition to_associational_cps {n:nat} (xs:tuple Z n) + {T} (f:list limb->T) := + map_cps weight (seq 0 n) + (fun r => + to_list_cps n xs (fun rr => combine_cps r rr f)). + Definition to_associational {n} xs := @to_associational_cps n xs _ id. + Definition eval {n} x := @to_associational_cps n x _ Associational.eval. + Lemma to_associational_cps_id {n} x {T} f: + @to_associational_cps n x T f = f (to_associational x). + Proof. cbv [to_associational_cps to_associational]; prove_id. Qed. + Hint Rewrite @to_associational_cps_id : uncps. + Lemma eval_to_associational {n} x : + Associational.eval (@to_associational n x) = eval x. + Proof. cbv [to_associational_cps eval to_associational]; prove_eval. Qed. + Hint Rewrite @eval_to_associational : push_basesystem_eval. + + (** Converting from associational to positional *) + + Program Definition zeros n : tuple Z n := Tuple.from_list n (List.map (fun _ => 0) (List.seq 0 n)) _. + Next Obligation. autorewrite with distr_length; reflexivity. Qed. + Lemma eval_zeros n : eval (zeros n) = 0. + Proof. + cbv [eval Associational.eval to_associational_cps zeros]; + autorewrite with uncps; rewrite Tuple.to_list_from_list. + generalize dependent (List.seq 0 n); intro xs; induction xs; simpl; nsatz. + Qed. Hint Rewrite eval_zeros : push_basesystem_eval. + + Definition add_to_nth_cps {n} i x t {T} (f:tuple Z n->T) := + @on_tuple_cps _ _ 0 (update_nth_cps i (runtime_add x)) n n t _ f. + Definition add_to_nth {n} i x t := @add_to_nth_cps n i x t _ id. + Hint Opaque add_to_nth : uncps. + Lemma add_to_nth_cps_id {n} i x xs {T} f: + @add_to_nth_cps n i x xs T f = f (add_to_nth i x xs). + Proof. + cbv [add_to_nth_cps add_to_nth]; erewrite !on_tuple_cps_correct + by (intros; autorewrite with uncps; reflexivity); prove_id. + Unshelve. + intros; subst. autorewrite with uncps push_id. distr_length. + Qed. Hint Rewrite @add_to_nth_cps_id : uncps. + Lemma eval_add_to_nth {n} (i:nat) (x:Z) (H:(i progress (apply Zminus_eq; ring_simplify) + | _ => progress autorewrite with push_basesystem_eval cancel_pair distr_length + | _ => progress rewrite <-?ListUtil.map_nth_default_always, ?map_fst_combine, ?List.firstn_all2, ?ListUtil.map_nth_default_always, ?nth_default_seq_inbouns, ?plus_O_n + end; trivial; lia. + Unshelve. + intros; subst. autorewrite with uncps push_id. distr_length. + Qed. Hint Rewrite @eval_add_to_nth using omega : push_basesystem_eval. + + Fixpoint place_cps (t:limb) (i:nat) {T} (f:nat * Z->T) := + if dec (fst t mod weight i = 0) + then f (i, let c := fst t / weight i in (c * snd t)%RT) + else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end. + Lemma place_cps_in_range (t:limb) (n:nat) : (fst (place_cps t n id) < S n)%nat. + Proof. induction n; simpl; break_match; simpl; omega. Qed. + Lemma weight_place_cps t i : weight (fst (place_cps t i id)) * snd (place_cps t i id) = fst t * snd t. + Proof. + induction i; cbv [id]; simpl place_cps; break_match; + autorewrite with cancel_pair; + try find_apply_lem_hyp Z_div_exact_full_2; nsatz || auto. + Qed. + Definition place t i := place_cps t i id. + Hint Opaque place : uncps. + Lemma place_cps_id t i {T} f : + @place_cps t i T f = f (place t i). + Proof. cbv [place]; induction i; prove_id. Qed. + Hint Rewrite place_cps_id : uncps. + Definition from_associational_cps n (p:list limb) {T} (f:tuple Z n->T):= + fold_right_cps (fun t st => place_cps t (pred n) (fun p=> add_to_nth_cps (fst p) (snd p) st id)) (zeros n) p f. + Definition from_associational n p := from_associational_cps n p id. + Hint Opaque from_associational : uncps. + Lemma from_associational_cps_id {n} p {T} f: + @from_associational_cps n p T f = f (from_associational n p). + Proof. cbv [from_associational_cps from_associational]; prove_id. Qed. + Hint Rewrite @from_associational_cps_id : uncps. + Lemma eval_from_associational {n} p (n_nonzero:n<>O): + eval (from_associational n p) = Associational.eval p. + Proof. + cbv [from_associational_cps from_associational]; induction p; + [|pose proof (place_cps_in_range a (pred n))]; prove_eval. + cbv [place]; rewrite weight_place_cps. nsatz. + Qed. Hint Rewrite @eval_from_associational using omega : push_basesystem_eval. + + Section Carries. + Context {modulo div : Z->Z->Z}. + Context {div_mod : forall a b:Z, b <> 0 -> + a = b * (div a b) + modulo a b}. + Definition carry_cps(index:nat) (p:list limb) {T} (f:list limb->T) := + @Associational.carry_cps modulo div (weight index) (weight (S index) / weight index) p T f. + Definition carry i p := carry_cps i p id. + Hint Opaque carry : uncps. + Lemma carry_cps_id i p {T} f: + @carry_cps i p T f = f (carry i p). + Proof. cbv [carry_cps carry]; prove_id; rewrite carry_cps_id; reflexivity. Qed. + Hint Rewrite carry_cps_id : uncps. + Lemma eval_carry i p: weight (S i) / weight i <> 0 -> + Associational.eval (carry i p) = Associational.eval p. + Proof. cbv [carry_cps carry]; intros; eapply @eval_carry; eauto. Qed. + Hint Rewrite @eval_carry : push_basesystem_eval. + End Carries. + End Positional. + End Positional. + Hint Rewrite + @Associational.sat_mul_cps_id + @Associational.sat_multerm_cps_id + @Associational.carry_cps_id + @Associational.carryterm_cps_id + @Associational.reduce_cps_id + @Associational.split_id + @Associational.mul_cps_id + @Positional.carry_cps_id + @Positional.from_associational_cps_id + @Positional.place_cps_id + @Positional.add_to_nth_cps_id + @Positional.to_associational_cps_id + : uncps. + Hint Rewrite + @Associational.eval_sat_mul + @Associational.eval_mul_noncps + @Positional.eval_to_associational + @Associational.eval_carry + @Associational.eval_carryterm + @Associational.eval_reduce + @Associational.eval_split_noncps + @Positional.eval_carry + @Positional.eval_from_associational + @Positional.eval_add_to_nth + using (omega || assumption) : push_basesystem_eval. +End B. + +Local Coercion Z.of_nat : nat >-> Z. +Import Coq.Lists.List.ListNotations. Local Open Scope list_scope. +Import B. + +Ltac assert_preconditions := + repeat match goal with + | |- context [Positional.from_associational_cps ?wt ?n] => + unique assert (wt 0%nat = 1) by (cbv; congruence) + | |- context [Positional.from_associational_cps ?wt ?n] => + unique assert (forall i, wt i <> 0) by (intros; apply Z.pow_nonzero; try (cbv; congruence); solve [zero_bounds]) + | |- context [Positional.from_associational_cps ?wt ?n] => + unique assert (n <> 0%nat) by (cbv; congruence) + | |- context [Positional.carry_cps?wt ?i] => + unique assert (wt (S i) / wt i <> 0) by (cbv; congruence) + end. + +Ltac op_simplify := + cbv - [runtime_add runtime_mul Let_In]; + cbv [runtime_add runtime_mul]; + repeat progress rewrite ?Z.mul_1_l, ?Z.mul_1_r, ?Z.add_0_l, ?Z.add_0_r. + +Ltac prove_op sz x := + cbv [Tuple.tuple Tuple.tuple'] in *; + repeat match goal with p : _ * Z |- _ => destruct p end; + apply lift2_sig; + eexists; cbv zeta beta; intros; + match goal with |- Positional.eval ?wt _ = ?op (Positional.eval ?wt ?a) (Positional.eval ?wt ?b) => + transitivity (Positional.eval wt (x wt a b)) + end; + [ apply f_equal; op_simplify; reflexivity + | assert_preconditions; + progress autorewrite with uncps push_id push_basesystem_eval; + reflexivity ] +. + +Section Ops. + Context + (modulo : Z -> Z -> Z) + (div : Z -> Z -> Z) + (div_mod : forall a b : Z, b <> 0 -> + a = b * div a b + modulo a b). + Local Infix "^" := tuple : type_scope. + + Let wt := fun i : nat => 2^(25 * (i / 2) + 26 * ((i + 1) / 2)). + Let sz := 10%nat. + Let sz2 := Eval compute in ((sz * 2) - 1)%nat. + + (* shorthand for many carries in a row *) + Definition chained_carries (w : nat -> Z) (p:list B.limb) (idxs : list nat) + {T} (f:list B.limb->T) := + fold_right_cps2 (@Positional.carry_cps w modulo div) p idxs f. + + Definition addT : + { add : (Z^sz -> Z^sz -> Z^sz)%type & + forall a b : Z^sz, + let eval {n} := Positional.eval (n := n) wt in + eval (add a b) = eval a + eval b }. + Proof. + prove_op sz ( + fun wt a b => + Positional.to_associational_cps (n := sz) wt a + (fun r => Positional.to_associational_cps (n := sz) wt b + (fun r0 => Positional.from_associational_cps wt sz (r ++ r0) id + ))). + Defined. + + + Definition mulT : + {mul : (Z^sz -> Z^sz -> Z^sz2)%type & + forall a b : Z^sz, + let eval {n} := Positional.eval (n := n) wt in + eval (mul a b) = eval a * eval b }. + Proof. + let x := (eval cbv [chained_carries seq fold_right_cps2 sz2] in + (fun w a b => + Positional.to_associational_cps (n := sz) w a + (fun r => Positional.to_associational_cps (n := sz) w b + (fun r0 => Associational.mul_cps r r0 + (fun r1 => Positional.from_associational_cps w sz2 r1 + (fun r2 => Positional.to_associational_cps w r2 + (fun r3 => chained_carries w r3 (seq 0 sz2) + (fun r13 => Positional.from_associational_cps w sz2 r13 id + )))))))) in + prove_op sz x. + Time Defined. (* Finished transaction in 139.086 secs *) + +End Ops. + +Eval cbv [projT1 addT lift2_sig proj1_sig] in (projT1 addT). +Eval cbv [projT1 mulT lift2_sig proj1_sig] in + (fun m d div_mod => projT1 (mulT m d div_mod)). diff --git a/src/Util/CPSUtil.v b/src/Util/CPSUtil.v new file mode 100644 index 0000000000..5d2a803993 --- /dev/null +++ b/src/Util/CPSUtil.v @@ -0,0 +1,244 @@ +Require Import Coq.Lists.List. Import ListNotations. +Require Import Coq.ZArith.ZArith Coq.omega.Omega. +Require Import Crypto.Util.ListUtil Crypto.Util.Tactics. +Require Crypto.Util.Tuple. Local Notation tuple := Tuple.tuple. +Local Open Scope Z_scope. + +Lemma push_id {A} (a:A) : id a = a. reflexivity. Qed. +Create HintDb push_id discriminated. Hint Rewrite @push_id : push_id. + +Lemma update_nth_id {T} i (xs:list T) : ListUtil.update_nth i id xs = xs. +Proof. + revert xs; induction i; destruct xs; simpl; solve [ trivial | congruence ]. +Qed. + +Lemma map_fst_combine {A B} (xs:list A) (ys:list B) : List.map fst (List.combine xs ys) = List.firstn (length ys) xs. +Proof. + revert xs; induction ys; destruct xs; simpl; solve [ trivial | congruence ]. +Qed. + +Lemma map_snd_combine {A B} (xs:list A) (ys:list B) : List.map snd (List.combine xs ys) = List.firstn (length xs) ys. +Proof. + revert xs; induction ys; destruct xs; simpl; solve [ trivial | congruence ]. +Qed. + +Lemma nth_default_seq_inbouns d s n i (H:(i < n)%nat) : + List.nth_default d (List.seq s n) i = (s+i)%nat. +Proof. + progress cbv [List.nth_default]. + rewrite ListUtil.nth_error_seq. + break_innermost_match; solve [ trivial | omega ]. +Qed. + +Lemma mod_add_mul_full a b c k m : m <> 0 -> c mod m = k mod m -> + (a + b * c) mod m = (a + b * k) mod m. +Proof. + intros; rewrite Z.add_mod, Z.mul_mod by auto. + match goal with H : _ mod _ = _ mod _ |- _ => rewrite H end. + rewrite <-Z.mul_mod, <-Z.add_mod by auto; reflexivity. +Qed. + +(* TODO +Lemma to_nat_neg : forall x, x < 0 -> Z.to_nat x = 0%nat. +Proof. destruct x; try reflexivity; intros. pose proof (Pos2Z.is_pos p). omega. Qed. + *) + +Fixpoint map_cps {A B} (g : A->B) ls + {T} (f:list B->T):= + match ls with + | nil => f nil + | a :: t => map_cps g t (fun r => f (g a :: r)) + end. +Lemma map_cps_correct {A B} g ls: forall {T} f, + @map_cps A B g ls T f = f (map g ls). +Proof. induction ls; simpl; intros; rewrite ?IHls; reflexivity. Qed. +Create HintDb uncps discriminated. Hint Rewrite @map_cps_correct : uncps. + +Fixpoint flat_map_cps {A B} (g:A->forall {T}, (list B->T)->T) (ls : list A) {T} (f:list B->T) := + match ls with + | nil => f nil + | (x::tl)%list => g x (fun r => flat_map_cps g tl (fun rr => f (r ++ rr))%list) + end. +Lemma flat_map_cps_correct {A B} (g:A->forall {T}, (list B->T)->T) ls : + forall {T} (f:list B->T), + (forall x T h, @g x T h = h (g x id)) -> + @flat_map_cps A B g ls T f = f (List.flat_map (fun x => g x id) ls). +Proof. + induction ls; intros; [reflexivity|]. + simpl flat_map_cps. simpl flat_map. + rewrite H; erewrite IHls by eassumption. + reflexivity. +Qed. +Hint Rewrite @flat_map_cps_correct using (intros; autorewrite with uncps; auto): uncps. + +Fixpoint from_list_default'_cps {A} (d y:A) n xs: + forall {T}, (Tuple.tuple' A n -> T) -> T:= + match n as n0 return (forall {T}, (Tuple.tuple' A n0 ->T) ->T) with + | O => fun T f => f y + | S n' => fun T f => + match xs with + | nil => from_list_default'_cps d d n' nil (fun r => f (r, y)) + | x :: xs' => from_list_default'_cps d x n' xs' (fun r => f (r, y)) + end + end. +Lemma from_list_default'_cps_correct {A} n : forall d y l {T} f, + @from_list_default'_cps A d y n l T f = f (Tuple.from_list_default' d y n l). +Proof. + induction n; intros; simpl; [reflexivity|]. + break_match; subst; apply IHn. +Qed. +Definition from_list_default_cps {A} (d:A) n (xs:list A) : + forall {T}, (Tuple.tuple A n -> T) -> T:= + match n as n0 return (forall {T}, (Tuple.tuple A n0 ->T) ->T) with + | O => fun T f => f tt + | S n' => fun T f => + match xs with + | nil => from_list_default'_cps d d n' nil f + | x :: xs' => from_list_default'_cps d x n' xs' f + end + end. +Lemma from_list_default_cps_correct {A} n : forall d l {T} f, + @from_list_default_cps A d n l T f = f (Tuple.from_list_default d n l). +Proof. + destruct n; intros; simpl; [reflexivity|]. + break_match; auto using from_list_default'_cps_correct. +Qed. +Hint Rewrite @from_list_default_cps_correct : uncps. +Fixpoint to_list'_cps {A} n + {T} (f:list A -> T) : Tuple.tuple' A n -> T := + match n as n0 return (Tuple.tuple' A n0 -> T) with + | O => fun x => f [x] + | S n' => fun (xs: Tuple.tuple' A (S n')) => + let (xs', x) := xs in + to_list'_cps n' (fun r => f (x::r)) xs' + end. +Lemma to_list'_cps_correct {A} n: forall t {T} f, + @to_list'_cps A n T f t = f (Tuple.to_list' n t). +Proof. + induction n; simpl; intros; [reflexivity|]. + destruct_head prod. apply IHn. +Qed. +Definition to_list_cps' {A} n {T} (f:list A->T) + : Tuple.tuple A n -> T := + match n as n0 return (Tuple.tuple A n0 ->T) with + | O => fun _ => f nil + | S n' => to_list'_cps n' f + end. +Definition to_list_cps {A} n t {T} f := + @to_list_cps' A n T f t. +Lemma to_list_cps_correct {A} n t {T} f : + @to_list_cps A n t T f = f (Tuple.to_list n t). +Proof. cbv [to_list_cps to_list_cps' Tuple.to_list]; break_match; auto using to_list'_cps_correct. Qed. +Hint Rewrite @to_list_cps_correct : uncps. + +Definition on_tuple_cps {A B} (d:B) (g:list A ->forall {T},(list B->T)->T) {n m} + (xs : Tuple.tuple A n) {T} (f:tuple B m ->T) := + to_list_cps n xs (fun r => g r (fun rr => from_list_default_cps d m rr f)). +Lemma on_tuple_cps_correct {A B} d (g:list A -> forall {T}, (list B->T)->T) + {n m} xs {T} f + (Hg : forall x {T} h, @g x T h = h (g x id)) : forall H, + @on_tuple_cps A B d g n m xs T f = f (@Tuple.on_tuple A B (fun x => g x id) n m H xs). +Proof. + cbv [on_tuple_cps Tuple.on_tuple]; intros. + rewrite to_list_cps_correct, Hg, from_list_default_cps_correct. + rewrite (Tuple.from_list_default_eq _ _ _ (H _ (Tuple.length_to_list _))). + reflexivity. +Qed. Hint Rewrite @on_tuple_cps_correct using (intros; autorewrite with uncps; auto): uncps. + +Fixpoint update_nth_cps {A} n (g:A->A) xs {T} (f:list A->T) := + match n with + | O => + match xs with + | [] => f [] + | x' :: xs' => f (g x' :: xs') + end + | S n' => + match xs with + | [] => f [] + | x' :: xs' => update_nth_cps n' g xs' (fun r => f (x' :: r)) + end + end. +Lemma update_nth_cps_correct {A} n g: forall xs T f, + @update_nth_cps A n g xs T f = f (update_nth n g xs). +Proof. induction n; intros; simpl; break_match; try apply IHn; reflexivity. Qed. +Hint Rewrite @update_nth_cps_correct : uncps. + +Fixpoint combine_cps {A B} (la :list A) (lb : list B) + {T} (f:list (A*B)->T) := + match la with + | nil => f nil + | a :: tla => + match lb with + | nil => f nil + | b :: tlb => combine_cps tla tlb (fun lab => f ((a,b)::lab)) + end + end. +Lemma combine_cps_correct {A B} la: forall lb {T} f, + @combine_cps A B la lb T f = f (combine la lb). +Proof. + induction la; simpl combine_cps; simpl combine; intros; + try break_match; try apply IHla; reflexivity. +Qed. +Hint Rewrite @combine_cps_correct: uncps. + +(* differs from fold_right_cps in that the functional argument `g` is also a CPS function *) +Fixpoint fold_right_cps2 {A B} (g : B -> A -> forall {T}, (A->T)->T) (a0 : A) (l : list B) {T} (f : A -> T) := + match l with + | nil => f a0 + | b :: tl => fold_right_cps2 g a0 tl (fun r => g b r f) + end. +Lemma fold_right_cps2_correct {A B} g a0 l : forall {T} f, + (forall b a T h, @g b a T h = h (@g b a A id)) -> + @fold_right_cps2 A B g a0 l T f = f (List.fold_right (fun b a => @g b a A id) a0 l). +Proof. + induction l; intros; [reflexivity|]. + simpl fold_right_cps2. simpl fold_right. + rewrite H; erewrite IHl by eassumption. + rewrite H; reflexivity. +Qed. +Hint Rewrite @fold_right_cps2_correct using (intros; autorewrite with uncps; auto): uncps. + +Definition fold_right_no_starter {A} (f:A->A->A) ls : option A := + match ls with + | nil => None + | cons x tl => Some (List.fold_right f x tl) + end. +Lemma fold_right_min ls x : + x = List.fold_right Z.min x ls + \/ List.In (List.fold_right Z.min x ls) ls. +Proof. + induction ls; intros; simpl in *; try tauto. + match goal with |- context [Z.min ?x ?y] => + destruct (Z.min_spec x y) as [[? Hmin]|[? Hmin]] + end; rewrite Hmin; tauto. +Qed. +Lemma fold_right_no_starter_min ls : forall x, + fold_right_no_starter Z.min ls = Some x -> + List.In x ls. +Proof. + cbv [fold_right_no_starter]; intros; destruct ls; try discriminate. + inversion H; subst; clear H. + destruct (fold_right_min ls z); + simpl List.In; tauto. +Qed. +Fixpoint fold_right_cps {A B} (g:B->A->A) (a0:A) (l:list B) {T} (f:A->T) := + match l with + | nil => f a0 + | cons a tl => fold_right_cps g a0 tl (fun r => f (g a r)) + end. +Lemma fold_right_cps_correct {A B} g a0 l: forall {T} f, + @fold_right_cps A B g a0 l T f = f (List.fold_right g a0 l). +Proof. induction l; intros; simpl; rewrite ?IHl; auto. Qed. +Hint Rewrite @fold_right_cps_correct : uncps. + +Definition fold_right_no_starter_cps {A} g ls {T} (f:option A->T) := + match ls with + | nil => f None + | cons x tl => f (Some (List.fold_right g x tl)) + end. +Lemma fold_right_no_starter_cps_correct {A} g ls {T} f : + @fold_right_no_starter_cps A g ls T f = f (fold_right_no_starter g ls). +Proof. + cbv [fold_right_no_starter_cps fold_right_no_starter]; break_match; reflexivity. +Qed. +Hint Rewrite @fold_right_no_starter_cps_correct : uncps. diff --git a/src/Util/Prod.v b/src/Util/Prod.v index bcd9404a60..6e6c7d3c49 100644 --- a/src/Util/Prod.v +++ b/src/Util/Prod.v @@ -16,6 +16,10 @@ Local Arguments f_equal {_ _} _ {_ _} _. Scheme Equality for prod. +Definition fst_pair {A B} (a:A) (b:B) : fst (a,b) = a := eq_refl. +Definition snd_pair {A B} (a:A) (b:B) : snd (a,b) = b := eq_refl. +Create HintDb cancel_pair discriminated. Hint Rewrite @fst_pair @snd_pair : cancel_pair. + (** ** Equality for [prod] *) Section prod. (** *** Projecting an equality of a pair to equality of the first components *) diff --git a/src/Util/Sigma.v b/src/Util/Sigma.v index 57c82df682..7a1d0cacb7 100644 --- a/src/Util/Sigma.v +++ b/src/Util/Sigma.v @@ -16,6 +16,15 @@ Local Arguments f_equal {_ _} _ {_ _} _. (** ** Equality for [sigT] *) Section sigT. + (* Lift foralls out of sigT proofs and leave a sig goal *) + Definition lift2_sig {R S T} f (g:R->S) + (X : forall a b, {prod | g prod = f a b}) : + { op : T -> T -> R & forall a b, g (op a b) = f a b }. + Proof. + exists (fun a b => proj1_sig (X a b)). + exact (fun a b => proj2_sig (X a b)). + Defined. + (** *** Projecting an equality of a pair to equality of the first components *) Definition pr1_path {A} {P : A -> Type} {u v : sigT P} (p : u = v) : projT1 u = projT1 v diff --git a/src/Util/ZUtil.v b/src/Util/ZUtil.v index ee280ca067..4c6e2441da 100644 --- a/src/Util/ZUtil.v +++ b/src/Util/ZUtil.v @@ -766,7 +766,6 @@ Module Z. apply Z.mod_mul, Z.pow_nonzero; omega. } Qed. - Lemma odd_mod : forall a b, (b <> 0)%Z -> Z.odd (a mod b) = if Z.odd b then xorb (Z.odd a) (Z.odd (a / b)) else Z.odd a. Proof.