(* (Joint Center) Library multinom *)

(* (c) Copyright Microsoft Corporation and Inria. All rights reserved. *)
Require Import ssreflect ssrfun ssrbool eqtype ssrnat seq choice fintype.
Require Import tuple finfun bigop ssralg poly.

Require Import multi.

Require Import generic_quotient.

Set Implicit Arguments.

Local Open Scope ring_scope.
Local Open Scope quotient_scope.
Import GRing.Theory.

Module Multinomial.

Section Multinomial.

(* Type of variable names.  Countability is what lets variables be coded
   as natural numbers with [pickle] / [pickle_inv] below. *)
Variable X : countType.

Section MultinomialRing.

(* Ring of coefficients. *)
Variable R : ringType.

(** Syntax of polynomial expressions: coefficients from [R], variables
    from [X], and binary sums and products. *)
Inductive multi_term :=
| Coef : R -> multi_term
| Var : X -> multi_term
| Sum : multi_term -> multi_term -> multi_term
| Prod : multi_term -> multi_term -> multi_term.

(** Boolean structural equality on [multi_term].  Coefficients are compared
    with the equality of [R] and variables with the equality of [X].
    [Sum]/[Prod] nodes are compared componentwise.  Terms built with
    different head constructors are unequal. *)
Fixpoint eqm m m' :=
  match m with
    | Coef x => if m' is Coef y then x == y else false
    | Var n => if m' is Var n' then n == n' else false
    | Sum p q => if m' is Sum p' q' then (eqm p p') && (eqm q q') else false
    | Prod p q => if m' is Prod p' q' then (eqm p p') && (eqm q q') else false
  end.

(** [eqm] reflects Leibniz equality on [multi_term]; this is the axiom
    needed to equip [multi_term] with an [eqType] structure. *)
(* NOTE(review): no proof script follows this statement in this copy of
   the file; one must be supplied before it compiles. *)
Lemma eqm_eq : Equality.axiom eqm.

Definition multi_term_eqMixin := EqMixin eqm_eq.
Canonical multi_term_eqType :=
  Eval hnf in EqType multi_term multi_term_eqMixin.

(** Skeleton of a [multi_term].  Each coefficient is replaced by an index
    ([CoefS]) into a separate coefficient list, and each variable by its
    [pickle] code ([VarS]).  Only natural numbers occur, which makes the
    type countable. *)
Inductive multi_skel :=
| CoefS : nat -> multi_skel
| VarS : nat -> multi_skel
| SumS : multi_skel -> multi_skel -> multi_skel
| ProdS : multi_skel -> multi_skel -> multi_skel.

(** Serialise a skeleton as a prefix-tagged list of naturals.
    - Leaves: tag [0] ([CoefS]) or [1] ([VarS]), followed by the payload.
    - Binary nodes: tag [2] ([SumS]) or [3] ([ProdS]), followed by the
      encodings of both children, left child first. *)
Fixpoint encode_multi_skel m :=
  match m with
    | CoefS i => [:: 0%N; i]
    | VarS n => [:: 1%N; n]
    | SumS p q => 2%N :: (encode_multi_skel p) ++ (encode_multi_skel q)
    | ProdS p q => 3%N :: (encode_multi_skel p) ++ (encode_multi_skel q)
  end.

(** Decode a tag list produced by [encode_multi_skel] into the list of
    skeletons it contains, in order.  A binary tag combines the first two
    skeletons decoded from the rest of the list.  Unrecognised tags, or a
    binary tag with fewer than two decoded operands, make that (sub)decoding
    return [[::]]. *)
Fixpoint decode_multi_skel_rec s :=
  match s with
    | 0%N :: i :: r => (CoefS i) :: (decode_multi_skel_rec r)
    | 1%N :: n :: r => (VarS n) :: (decode_multi_skel_rec r)
    (* The inner matches must be closed with [end]; otherwise they would
       swallow the remaining outer branches. *)
    | 2%N :: r => match decode_multi_skel_rec r with
                  | p :: q :: r' => (SumS p q) :: r'
                  | _ => [::]
                  end
    | 3%N :: r => match decode_multi_skel_rec r with
                  | p :: q :: r' => (ProdS p q) :: r'
                  | _ => [::]
                  end
    | _ => [::]
  end.

(** Decode a complete tag list.  This succeeds only when the list decodes
    to exactly one skeleton. *)
Definition decode_multi_skel s :=
  match decode_multi_skel_rec s with
  | [:: m] => Some m
  | _ => None
  end.

(** Decoding [encode_multi_skel m ++ s] yields [m] followed by the decoding
    of the suffix [s]. *)
(* NOTE(review): no proof script follows in this copy — TODO restore. *)
Lemma code_multi_skel_recK :
  forall s m, decode_multi_skel_rec ((encode_multi_skel m) ++ s)
  = m :: (decode_multi_skel_rec s).

(** Round trip: decoding the encoding of [m] gives [Some m].  This is the
    partial cancellation consumed by [PcanCountMixin]. *)
(* NOTE(review): no proof script follows in this copy — TODO restore. *)
Lemma code_multi_skelK : pcancel encode_multi_skel decode_multi_skel.

(* Countable structure on [multi_skel], obtained through the injection
   into [seq nat] given by [code_multi_skelK]. *)
Definition multi_skel_countMixin := PcanCountMixin code_multi_skelK.

(* The choice and eq structures are both derived from the count mixin. *)
Definition multi_skel_choiceMixin :=
  CountChoiceMixin multi_skel_countMixin.
Definition multi_skel_eqMixin :=
  Countable.EqMixin multi_skel_countMixin.

(* Canonical instances, declared eq -> choice -> count as the structure
   hierarchy requires. *)
Canonical multi_skel_eqType :=
  EqType multi_skel multi_skel_eqMixin.
Canonical multi_skel_choiceType :=
  ChoiceType multi_skel multi_skel_choiceMixin.
Canonical multi_skel_countType :=
  CountType multi_skel multi_skel_countMixin.

(** Split a term into its skeleton and its coefficient list.
    - Each [Coef x] is appended to the accumulator [s] and replaced by
      [CoefS] of its position in the list.
    - Each [Var x] becomes [VarS (pickle x)].
    - Subterms are traversed left to right, threading the accumulator. *)
Fixpoint encode_multi_term_rec (s : seq R) (m : multi_term) :=
  match m with
    | Coef x => ((CoefS (size s)), (rcons s x))
    | Var x => ((VarS (pickle x)), s)
    | Sum p q =>
      let: (p', s') := encode_multi_term_rec s p in
      let: (q', s'') := encode_multi_term_rec s' q in
        ((SumS p' q'), s'')
    | Prod p q =>
      let: (p', s') := encode_multi_term_rec s p in
      let: (q', s'') := encode_multi_term_rec s' q in
        ((ProdS p' q'), s'')
  end.

(** Encode a term, starting from an empty coefficient list. *)
(* [%type] forces the product type: [*] is ring multiplication under the
   open [ring_scope]. *)
Definition encode_multi_term (m : multi_term) : (multi_skel * seq R)%type :=
  encode_multi_term_rec [::] m.

(* The variable term whose [pickle] code is [n].  Codes that do not decode
   to an element of [X] are mapped to the constant term [Coef 0]. *)
(* The proofs of [nbvar_term_reify] and [interp_reify] unfold this
   definition, so its term shape matters. *)
Definition mkVar n : multi_term := odflt (Coef 0) (omap Var (pickle_inv _ n)).

(** Rebuild a term from a skeleton [m], reading coefficients from [s].
    An out-of-range coefficient index yields [0], the default of [nth]. *)
Fixpoint decode_multi_term_rec (s : seq R) (m : multi_skel) :=
    match m with
    | CoefS i => Coef (nth (0:R) s i)
    | VarS n => mkVar n
    | SumS p q =>
      Sum (decode_multi_term_rec s p) (decode_multi_term_rec s q)
    | ProdS p q =>
      Prod (decode_multi_term_rec s p) (decode_multi_term_rec s q)
    end.

(** Decode a (skeleton, coefficient list) pair back into a term. *)
(* [%type] forces the product type: [*] is ring multiplication under the
   open [ring_scope]. *)
Definition decode_multi_term (ms : (multi_skel * seq R)%type) :=
  decode_multi_term_rec ms.2 ms.1.

(** Starting the encoding from an accumulator [s] only prepends [s] to the
    coefficient list obtained from the empty accumulator. *)
(* NOTE(review): no proof script follows in this copy — TODO restore. *)
Lemma encode_multi_term_ds : forall s m,
   s ++ (snd (encode_multi_term m)) = (snd (encode_multi_term_rec s m)).

(** Decoding inverts encoding for any starting accumulator [s], even when
    extra coefficients [ds] are appended after the produced list. *)
(* NOTE(review): no proof script follows in this copy — TODO restore. *)
Lemma code_multi_term_recK : forall s ds m,
   decode_multi_term_rec ((encode_multi_term_rec s m).2 ++ ds)
   (encode_multi_term_rec s m).1 = m.

(** Round trip on whole terms: [decode_multi_term] cancels
    [encode_multi_term]. *)
(* NOTE(review): no proof script follows in this copy — TODO restore. *)
Lemma code_multi_termK : cancel encode_multi_term decode_multi_term.

(* Choice structure on [multi_term], transported from
   [multi_skel * seq R] along the cancellation [code_multi_termK]. *)
Definition multi_term_choiceMixin :=
  @CanChoiceMixin _ multi_term _ _ code_multi_termK.
Canonical multi_term_choiceType :=
  ChoiceType multi_term multi_term_choiceMixin.

(** Number of variable slots needed to interpret [t]: one more than the
    largest [pickle] code of a variable occurring in [t], and [0] when [t]
    has no variables. *)
Fixpoint nbvar_term t :=
  match t with
    | Coef _ => 0%N
    | Var x => (pickle x).+1
    | Sum u v => maxn (nbvar_term u) (nbvar_term v)
    | Prod u v => maxn (nbvar_term u) (nbvar_term v)
  end.

(** [multi n]: polynomials in [n] variables over [R], built as the [n]-fold
    iterated univariate polynomial ring; [multi 0] is [R] itself. *)
Fixpoint multi n :=
  match n with
  | n'.+1 => [ringType of {poly multi n'}]
  | 0%N => R
  end.

(** View [p : multi n] in [multi (m + n)] by applying the constant
    polynomial embedding [%:P] [m] times. *)
Fixpoint inject n m (p : multi n) {struct m} : multi (m + n) :=
  match m return multi (m + n) with
  | m'.+1 => (inject m' p)%:P
  | 0%N => p
  end.

(* NOTE(review): none of the lemmas below has a proof script in this copy
   of the file — TODO restore. *)

(** [inject] is injective. *)
Lemma inject_inj : forall i m, injective (@inject i m).

(** [inject] maps [0] to [0]. *)
Lemma inject0 : forall i m, @inject i m 0 = 0.

(** [inject] sends exactly [0] to [0]. *)
Lemma inject_eq0 : forall i m p, (@inject i m p == 0) = (p == 0).

(** After at least one injection step the result is a constant polynomial:
    its size is [1] if [p != 0] and [0] otherwise. *)
Lemma size_inject : forall i m p, size (@inject i m.+1 p) = (p != 0 : nat).

(** Embed [multi i] into [multi n], given [Emn : (m + i)%N = n].  The
    element is injected [m] times, and the result type [multi (m + i)] is
    then transported along [Emn]. *)
Definition cast_multi i m n Emn : multi i -> multi n :=
  let: erefl in _ = n' := Emn return _ -> multi n' in inject m.

(* The variable of index [i] in [multi n]: the indeterminate ['X] of
   [multi i.+1], cast into [multi n].  [valP i : i < n] supplies the size
   bound required by [subnK]. *)
(* The proof of [interp_reify] unfolds this definition and generalizes its
   [subnK _] argument, so its exact shape matters. *)
Definition multi_var n (i : 'I_n) := cast_multi (subnK (valP i)) 'X.

Notation "'X_ i" := (multi_var i).

(** [inject] and [cast_multi] are ring morphisms; the canonical instances
    let generic morphism lemmas ([rmorph0], [rmorphD], ...) apply to them. *)
(* NOTE(review): the two lemmas below have no proof script in this copy
   of the file — TODO restore. *)
Lemma inject_is_rmorphism : forall m n, rmorphism (@inject n m).
Canonical inject_rmorphism m n := RMorphism (inject_is_rmorphism m n).

Lemma cast_multi_is_rmorphism i m n Enm : rmorphism (@cast_multi i m n Enm).
Canonical cast_multi_rmorphism i m n e :=
  RMorphism (@cast_multi_is_rmorphism i m n e).

(** Constant embedding of coefficients into [multi n].  Since [multi 0] is
    [R], this is the cast along [addn0 n : (n + 0)%N = n]. *)
Definition multiC n : R -> multi n := cast_multi (addn0 n).
(* NOTE(review): no proof script follows in this copy — TODO restore. *)
Lemma multiC_is_rmorphism n : rmorphism (multiC n).
Canonical multiC_rmorphism n := RMorphism (multiC_is_rmorphism n).

(** Interpret a term in [multi n].
    - Coefficients go through [multiC].
    - A variable with [pickle] code [i < n] becomes the indeterminate of
      [multi i.+1], cast into [multi n].
    - Variables with code [i >= n] are sent to [0].
    - [Sum] and [Prod] become ring addition and multiplication. *)
Fixpoint interp n m : multi n :=
  match m with
    | Coef x => multiC n x
    | Var x => let i := pickle x in
      (* Dependent [if]: the [then] branch receives the proof of [i < n]
         needed to build the cast. *)
      (if i < n as b return (i < n) = b -> multi n
        then fun iltn => cast_multi (subnK iltn) 'X_(Ordinal (leqnn i.+1))
        else fun _ => 0) (refl_equal (i < n))
    | Sum p q => interp n p + interp n q
    | Prod p q => interp n p * interp n q
  end.

(* Two casts into the same ring [multi n'], possibly along different
   equations, are equal exactly when their arguments are equal. *)
(* NOTE(review): no proof script follows in this copy — TODO restore. *)
Lemma cast_multi_inj n i i' n' (m1 m2 : multi n)
  (p1: (i + n)%N=n') (p2: (i' + n)%N=n') :
  cast_multi p1 m1 == cast_multi p2 m2 = (m1 == m2).

(** Compose the size equations of two successive casts into the equation
    for the combined cast; it is used in [cast_multi_add]. *)
(* NOTE(review): no proof script follows in this copy — TODO restore. *)
Lemma Emj_Enk i j k m n :
   forall (Emj : (m + i)%N = j) (Enk : (n + j)%N = k), (n + m + i)%N = k.

(* NOTE(review): none of the lemmas below has a proof script in this copy
   of the file — TODO restore. *)

(* Casting along an equation [0 + n = n] is the identity. *)
Lemma cast_multi_id n (e : (0 + n)%N = n) m : cast_multi e m = m.

(* A cast with one extra injection step is the constant polynomial of the
   shorter cast. *)
Lemma cast_multiS
  n i n' (m: multi n) (p: (i+n)%N = n') (pS: ((i.+1)+n)%N = n'.+1) :
  (cast_multi pS m) = (cast_multi p m)%:P.

(* Injecting [n + m] times is injecting [m] times, then casting [n] more
   steps along associativity of addition. *)
Lemma injectnm_cast_multi i m n p :
  inject (n + m)%N p =
  ((@cast_multi (m + i)%N n ((n + m) + i)%N (addnA _ _ _)) \o (inject m)) p.

(* Two successive casts combine into a single cast. *)
Lemma cast_multi_add i j k m n Emj Enk p :
  @cast_multi j n k Enk (@cast_multi i m j Emj p) =
  @cast_multi i (n + m)%N k (Emj_Enk Emj Enk) p.

(** If every variable of [m] has code below [n], interpreting [m] in a
    larger ring [multi n'] is the cast of its interpretation in [multi n]. *)
(* NOTE(review): no proof script follows in this copy — TODO restore. *)
Lemma interp_cast_multi n n' m (nltn' : n <= n') :
  nbvar_term m <= n -> interp n' m = cast_multi (subnK nltn') (interp n m).

(** Read back an element of [multi n] as a syntactic term.
    - For [n = 0] the element is a coefficient.
    - For [n = k.+1] it is a polynomial over [multi k] with coefficient list
      [cs]. It is rebuilt in Horner form
      [Sum c0 (Prod v (Sum c1 (Prod v (... (Coef 0)))))], where [v] is
      [mkVar k] and each [ci] is reified recursively. *)
Fixpoint reify n (m : multi n) :=
  match n return multi n -> multi_term with
  | 0%N => fun c => Coef c
  | k.+1 => fun q =>
      let: Polynomial cs _ := q in
      foldr (fun a acc => Sum (reify a) (Prod (mkVar k) acc)) (Coef 0) cs
  end m.
(* The reified form of an element of [multi n] only mentions variables
   whose [pickle] code is below [n]. *)
Lemma nbvar_term_reify : forall n (m : multi n), nbvar_term (reify m) <= n. Proof. elim=> // n Hn /= [p]; elim: p => //; set f := foldr _ _. move=> a s /= Hs ls; rewrite !geq_max (leqW (Hn _)). rewrite Hs ?andbT; last by case: s ls {Hs}; rewrite //= oner_eq0. rewrite /mkVar; case hx: pickle_inv => [x| ] //=. by have ->: pickle x = n by rewrite -(@pickle_invK X n) hx. Qed.
(* [interp n] is a left inverse of [reify]: interpreting the reified form
   of [m] in [multi n] gives back [m]. *)
(* NOTE(review): the fragments [#_| ], [#_ 0] and
   "case: rewrite {2} [n < n.+1]ltnSn" do not look like valid SSReflect
   syntax. This script may have been corrupted during extraction — TODO
   check it against the original development. *)
Lemma interp_reify : forall n (m : multi n), interp n (reify m) = m. Proof. elim=> [|n Hn] m /=; first by rewrite /multiC cast_multi_id. elim: m => p i. apply: val_inj=> /=. set f := foldr _ _. elim: p i=> [#_| ]. by rewrite /= [#_ 0]rmorph0 val_insubd /= eqxx. move=> t s Hs. rewrite [f _ ]/= [interp _ _ ]/=. rewrite (interp_cast_multi (leqnSn n)); last by rewrite nbvar_term_reify. rewrite [last _ _ ]/= => ls. move: (subnK (leqnSn n)); rewrite subSnn => En. have ->: En = (erefl (n.+1)); first exact: nat_irrelevance. rewrite [cast_multi _ _ ]/=. rewrite Hn. move: (refl_equal (n < n.+1)). rewrite /mkVar. case: rewrite {2} [n < n.+1]ltnSn=> e. rewrite /multi_var. move:(subnK _) (subnK _). rewrite [val _ ]/= subnn=> subnK subnK'. have ->: subnK = erefl (n.+1); first exact: nat_irrelevance. have ->: subnK' = erefl (n.+1); first exact: nat_irrelevance. rewrite ! [cast_multi _ _ ]/= addrC mulrC -cons_poly_def polyseq_cons Hs. move=> {Hs En e subnK' subnK}. case:s ls; last by move=> a s /=. rewrite /=. move=> tn0. by rewrite val_insubd tn0. move=> {Hs En e subnK' subnK}. case:s ls; first by rewrite GRing.oner_neq0. by move=> a s ->. Qed.
(** Interpreting the reification of [m : multi n] in a larger ring
    [multi n'] gives the cast of [m] into [multi n']. *)
(* Inside this section [multi : nat -> ringType] already has [R] fixed,
   so the binder is [multi n], not [multi R n]. *)
Lemma interp_reify_cast_multi :
  forall n n' (ltnn' : n <= n') (m : multi n),
  interp n' (reify m) = cast_multi (subnK ltnn') m.
Proof.
move=> n n' ltnn' m.
(* Move to [interp n]; the side condition is exactly [nbvar_term_reify]. *)
rewrite (interp_cast_multi ltnn'); last by rewrite nbvar_term_reify.
by apply/eqP; rewrite cast_multi_inj interp_reify.
Qed.

(** Two terms are equivalent when they interpret to the same element of
    [multi k], where [k] is the larger of their variable bounds. *)
Definition equivm m1 m2 :=
  let bound := maxn (nbvar_term m1) (nbvar_term m2) in
  interp bound m1 == interp bound m2.

(** [equivm] can be decided in any ring [multi n] with [n] at least the
    variable bound of both terms. *)
(* NOTE(review): no proof script follows in this copy — TODO restore. *)
Lemma interp_gtn n m1 m2 : maxn (nbvar_term m1) (nbvar_term m2) <= n ->
                           equivm m1 m2 = (interp n m1 == interp n m2).

(* [equivm] is an equivalence relation, packaged so that it can define
   the quotient below. *)
(* NOTE(review): the three lemmas below have no proof script in this copy
   of the file — TODO restore. *)
Lemma equivm_refl : reflexive equivm.

Lemma equivm_sym : symmetric equivm.

Lemma equivm_trans : transitive equivm.

Canonical equivm_equivRel := EquivRel equivm
  equivm_refl equivm_sym equivm_trans.

(* Multinomials: terms quotiented by [equivm]. *)
Definition multinom := {eq_quot equivm}.
(* Phantom-indexed alias so that the type can be written {multinom R}. *)
Definition multinom_of of phant X & phant R := multinom.

Local Notation "{ 'multinom' R }" := (multinom_of (Phant X) (Phant R))
   (at level 0, format "{ 'multinom' R }").
(* Quotient, eq and choice structures for both [multinom] and
   {multinom R}.  The order of these canonical declarations matters. *)
Canonical multinom_quotType := [quotType of multinom].
Canonical multinom_eqType := [eqType of multinom].
Canonical multinom_eqQuotType := [eqQuotType equivm of multinom].
Canonical multinom_choiceType := [choiceType of multinom].
Canonical multinom_of_quotType := [quotType of {multinom R}].
Canonical multinom_of_eqType := [eqType of {multinom R}].
Canonical multinom_of_eqQuotType := [eqQuotType equivm of {multinom R}].
Canonical multinom_of_choiceType := [choiceType of {multinom R}].

(** Equality in the quotient can be tested by interpretation in any ring
    [multi n] with [n] at least the variable bound of both terms. *)
(* NOTE(review): no proof script follows in this copy — TODO restore. *)
Lemma eqm_interp n m1 m2 : maxn (nbvar_term m1) (nbvar_term m2) <= n ->
         (interp n m1 == interp n m2) = (m1 == m2 %[mod {multinom R}]).

(* Constants of the quotient: the [Coef] constructor lifted through the
   canonical projection. *)
Definition cstm := lift_embed {multinom R} Coef.
Notation "c %:M" := (cstm c) (at level 2, format "c %:M").
Canonical pi_cstm_morph := PiEmbed cstm.

(* Variables of the quotient: the [Var] constructor lifted likewise. *)
Definition varm := lift_embed {multinom R} Var.
Notation "n %:X" := (varm n) (at level 2, format "n %:X").
Canonical pi_varm_morph := PiEmbed varm.

(* Operations on the quotient, lifted from the term constructors.  Each
   [pi_*] lemma states compatibility with the canonical projection.
   NOTE(review): those lemmas have no proof script in this copy — TODO
   restore. *)
Definition addm := lift_op2 {multinom R} Sum.
Lemma pi_addm : {morph \pi : x y / Sum x y >-> addm x y}.
Canonical pi_addm_morph := PiMorph2 pi_addm.

(* Negation is encoded as multiplication by the constant [-1]. *)
Definition Opp := Prod (Coef (-1)).
Definition oppm := lift_op1 {multinom R} Opp.
Lemma pi_oppm : {morph \pi : x / Opp x >-> oppm x}.
Canonical pi_oppm_morph := PiMorph1 pi_oppm.

Definition mulm := lift_op2 {multinom R} Prod.
Lemma pi_mulm : {morph \pi : x y / Prod x y >-> mulm x y}.
Canonical pi_mulm_morph := PiMorph2 pi_mulm.

(* Z-module axioms for the lifted addition and negation, with [0%:M] as
   zero; they yield the zmodType instances below.
   NOTE(review): these lemmas have no proof script in this copy — TODO
   restore. *)
Lemma addmA : associative addm.

Lemma addmC : commutative addm.

Lemma add0m : left_id 0%:M addm.

Lemma addmN : left_inverse 0%:M oppm addm.

Definition multinom_zmodMixin := ZmodMixin addmA addmC add0m addmN.
Canonical multinom_zmodType := ZmodType multinom multinom_zmodMixin.
Canonical multinom_of_zmodType := ZmodType {multinom R} multinom_zmodMixin.

(* Ring axioms for the lifted multiplication, with [1%:M] as unit; they
   yield the ringType instances below.
   NOTE(review): these lemmas have no proof script in this copy — TODO
   restore. *)
Lemma mulmA : associative mulm.

Lemma mul1m : left_id 1%:M mulm.

Lemma mulm1 : right_id 1%:M mulm.

Lemma mulm_addl : left_distributive mulm addm.

Lemma mulm_addr : right_distributive mulm addm.

Lemma nonzero1m : 1%:M != 0%:M.

Definition multinom_ringMixin := RingMixin mulmA mul1m mulm1 mulm_addl mulm_addr nonzero1m.
Canonical multinom_ringType := RingType multinom multinom_ringMixin.
Canonical multinom_of_ringType := RingType {multinom R} multinom_ringMixin.

End MultinomialRing.

(* The notations are declared again here, after the ring section is
   closed.  [multinom_of], [cstm] and [varm] now take [R] as an argument,
   hence the [@multinom_of _ ...] form.  Presumably this keeps the
   notations usable outside Section MultinomialRing — TODO confirm. *)
Notation "{ 'multinom' R }" := (@multinom_of _ (Phant X) (Phant R))
   (at level 0, format "{ 'multinom' R }").

Notation "c %:M" := (cstm c) (at level 2, format "c %:M").
Notation "n %:X" := (varm n) (at level 2, format "n %:X").

(* Closes Section Multinomial. *)
End Multinomial.
(* Closes Module Multinomial. *)
End Multinomial.