(* (Joint Center) Library mxtens *)

(* (c) Copyright Microsoft Corporation and Inria. All rights reserved. *)
Require Import ssreflect ssrfun ssrbool eqtype ssrnat seq choice fintype.
Require Import bigop ssralg matrix zmodp div.

Set Implicit Arguments.

Import GRing.Theory.
Local Open Scope nat_scope.
Local Open Scope ring_scope.

Section ExtraBigOp.

(** NOTE(review): the Proof ... Qed. scripts of the lemmas in this section
    are not present in this copy; they must be restored before it compiles. *)

(** A sum over 'I_(m + n) splits into the sum over the first m indices
    (embedded with lshift) plus the sum over the last n (embedded with
    rshift). *)
Lemma sumr_add : forall (R : ringType) m n (F : 'I_(m + n) -> R),
  \sum_(i < m + n) F i = \sum_(i < m) F (lshift _ i)
  + \sum_(i < n) F (rshift _ i).

(** Encoding of a pair (i, j) : 'I_m * 'I_n as the number i * n + j,
    which lies below m * n. *)
Lemma mxtens_index_proof m n (ij : 'I_m * 'I_n) : ij.1 * n + ij.2 < m * n.

Definition mxtens_index m n ij := Ordinal (@mxtens_index_proof m n ij).

(** Bounds needed to decode k : 'I_(m * n) into its quotient and its
    remainder by n. *)
Lemma mxtens_index_proof1 m n (k : 'I_(m * n)) : k %/ n < m.
Lemma mxtens_index_proof2 m n (k : 'I_(m * n)) : k %% n < n.

(** Decoding: k is sent to the pair (k %/ n, k %% n). *)
Definition mxtens_unindex m n k :=
  (Ordinal (@mxtens_index_proof1 m n k), Ordinal (@mxtens_index_proof2 m n k)).

(* Declare m and n as maximally inserted implicit arguments. *)
Implicit Arguments mxtens_index [[m] [n]].
Implicit Arguments mxtens_unindex [[m] [n]].

(** Decoding undoes encoding ... *)
Lemma mxtens_indexK m n : cancel (@mxtens_index m n) (@mxtens_unindex m n).

(** ... and encoding undoes decoding. *)
Lemma mxtens_unindexK m n : cancel (@mxtens_unindex m n) (@mxtens_index m n).

(** Spec type: its only constructor states that an index of 'I_(m * n)
    is mxtens_index (i, j) for some i : 'I_m and j : 'I_n. *)
CoInductive is_mxtens_index (m n : nat) : 'I_(m * n) -> Type :=
    IsMxtensIndex : forall (i : 'I_m) (j : 'I_n),
                   is_mxtens_index (mxtens_index (i, j)).

(** Every k : 'I_(m * n) satisfies the spec above. *)
Lemma mxtens_indexP (m n : nat) (k : 'I_(m * n)) : is_mxtens_index k.

(** The product of two sums is a single sum over 'I_(m * n) of the
    products of the terms selected by the decoded index. *)
Lemma mulr_sum (R : ringType) m n (Fm : 'I_m -> R) (Fn : 'I_n -> R) :
  (\sum_(i < m) Fm i) * (\sum_(i < n) Fn i)
  = \sum_(i < m * n) ((Fm (mxtens_unindex i).1) * (Fn (mxtens_unindex i).2)).

End ExtraBigOp.

Section ExtraMx.

(** NOTE(review): the Proof ... Qed. scripts of the lemmas in this section
    are not present in this copy; they must be restored before it compiles. *)

(** Casting the dimensions of a product is the same as casting the rows
    of the left factor and the columns of the right factor. *)
Lemma castmx_mul (R : ringType)
  (m m' n p p': nat) (em : m = m') (ep : p = p')
  (M : 'M[R]_(m, n)) (N : 'M[R]_(n, p)) :
  castmx (em, ep) (M *m N) = castmx (em, erefl _) M *m castmx (erefl _, ep) N.

(** Moving the inner-dimension cast of N onto the columns of M (through
    esym en) in a product. *)
Lemma mulmx_cast (R : ringType)
  (m n n' p p' : nat) (en : n' = n) (ep : p' = p)
  (M : 'M[R]_(m, n)) (N : 'M[R]_(n', p')) :
  M *m (castmx (en, ep) N) =
  (castmx (erefl _, (esym en)) M) *m (castmx (erefl _, ep) N).

(** A cast of a row concatenation is the concatenation of the casts of
    its two pieces. *)
Lemma castmx_row (R : Type) (m m' n1 n2 n1' n2' : nat)
  (eq_n1 : n1 = n1') (eq_n2 : n2 = n2') (eq_n12 : (n1 + n2 = n1' + n2')%N)
  (eq_m : m = m') (A1 : 'M[R]_(m, n1)) (A2 : 'M_(m, n2)) :
  castmx (eq_m, eq_n12) (row_mx A1 A2) =
  row_mx (castmx (eq_m, eq_n1) A1) (castmx (eq_m, eq_n2) A2).

(** Column counterpart of castmx_row. *)
Lemma castmx_col (R : Type) (m m' n1 n2 n1' n2' : nat)
  (eq_n1 : n1 = n1') (eq_n2 : n2 = n2') (eq_n12 : (n1 + n2 = n1' + n2')%N)
  (eq_m : m = m') (A1 : 'M[R]_(n1, m)) (A2 : 'M_(n2, m)) :
  castmx (eq_n12, eq_m) (col_mx A1 A2) =
  col_mx (castmx (eq_n1, eq_m) A1) (castmx (eq_n2, eq_m) A2).

(** Block counterpart: a cast of a 2x2 block matrix is the block matrix
    of the casts of its four blocks. *)
Lemma castmx_block (R : Type) (m1 m1' m2 m2' n1 n2 n1' n2' : nat)
  (eq_m1 : m1 = m1') (eq_n1 : n1 = n1') (eq_m2 : m2 = m2') (eq_n2 : n2 = n2')
  (eq_m12 : (m1 + m2 = m1' + m2')%N) (eq_n12 : (n1 + n2 = n1' + n2')%N)
  (ul : 'M[R]_(m1, n1)) (ur : 'M[R]_(m1, n2))
  (dl : 'M[R]_(m2, n1)) (dr : 'M[R]_(m2, n2)) :
  castmx (eq_m12, eq_n12) (block_mx ul ur dl dr) =
  block_mx (castmx (eq_m1, eq_n1) ul) (castmx (eq_m1, eq_n2) ur)
  (castmx (eq_m2, eq_n1) dl) (castmx (eq_m2, eq_n2) dr).

End ExtraMx.

Section MxTens.

(** NOTE(review): the Proof ... Qed. scripts of the lemmas in this section
    are not present in this copy; they must be restored before it compiles. *)

Variable R : ringType.

(** Tensor (Kronecker) product of A : 'M_(m, n) and B : 'M_(p, q).
    Row and column indices are decoded with mxtens_unindex, so the entry
    at (mxtens_index (i, j), mxtens_index (k, l)) is A i k * B j l
    (lemma tensmxE). *)
Definition tensmx {m n p q : nat}
  (A : 'M_(m, n)) (B : 'M_(p, q)) : 'M[R]_(_,_) := nosimpl
  (\matrix_(i, j) (A (mxtens_unindex i).1 (mxtens_unindex j).1
                 * B (mxtens_unindex i).2 (mxtens_unindex j).2)).

Notation "A *t B" := (tensmx A B)
  (at level 40, left associativity, format "A *t B").

(** Entry-wise characterisation of the tensor product. *)
Lemma tensmxE {m n p q} (A : 'M_(m, n)) (B : 'M_(p, q)) i j k l :
  (A *t B) (mxtens_index (i, j)) (mxtens_index (k, l)) = A i k * B j l.

(** A zero factor on either side gives the zero matrix. *)
Lemma tens0mx {m n p q} (M : 'M[R]_(p,q)) : (0 : 'M_(m,n)) *t M = 0.

Lemma tensmx0 {m n p q} (M : 'M[R]_(m,n)) : M *t (0 : 'M_(p,q)) = 0.

(** Tensoring on the left by a scalar matrix c%:M scales M by c, up to
    the casts 1 * m = m and 1 * n = n. *)
Lemma tens_scalar_mx (m n : nat) (c : R) (M : 'M_(m,n)):
  c%:M *t M = castmx (esym (mul1n _), esym (mul1n _)) (c *: M).

Lemma tens_scalar1mx (m n : nat) (M : 'M_(m,n)) :
  1 *t M = castmx (esym (mul1n _), esym (mul1n _)) M.

Lemma tens_scalarN1mx (m n : nat) (M : 'M_(m,n)) :
  (-1) *t M = castmx (esym (mul1n _), esym (mul1n _)) (-M).

(** Transposition distributes over the tensor product. *)
Lemma trmx_tens {m n p q} (M :'M[R]_(m,n)) (N : 'M[R]_(p,q)) :
  (M *t N)^T = M^T *t N^T.

(** Tensoring a column / row / block decomposition on the right by N
    tensors each piece, up to the cast given by mulnDl. *)
Lemma tens_col_mx {m n p q} (r : 'rV[R]_n)
  (M :'M[R]_(m, n)) (N : 'M[R]_(p, q)) :
  (col_mx r M) *t N =
  castmx (esym (mulnDl _ _ _), erefl _) (col_mx (r *t N) (M *t N)).

Lemma tens_row_mx {m n p q} (r : 'cV[R]_m) (M :'M[R]_(m,n)) (N : 'M[R]_(p,q)) :
  (row_mx r M) *t N =
  castmx (erefl _, esym (mulnDl _ _ _)) (row_mx (r *t N) (M *t N)).

Lemma tens_block_mx {m n p q}
  (ul : 'M[R]_1) (ur : 'rV[R]_n) (dl : 'cV[R]_m)
  (M :'M[R]_(m,n)) (N : 'M[R]_(p,q)) :
  (block_mx ul ur dl M) *t N =
  castmx (esym (mulnDl _ _ _), esym (mulnDl _ _ _))
  (block_mx (ul *t N) (ur *t N) (dl *t N) (M *t N)).

(** ntensmx_rec A k is A tensored with itself k.+1 times, nesting to the
    right: A *t (A *t ... A). *)
Fixpoint ntensmx_rec {m n} (A : 'M_(m,n)) k : 'M_(m ^ k.+1,n ^ k.+1) :=
  if k is k'.+1 then (A *t (ntensmx_rec A k')) else A.

(** k-th tensor power of A; it is 1 for k = 0. *)
Definition ntensmx {m n} (A : 'M_(m, n)) k := nosimpl
  (if k is k'.+1 return 'M[R]_(m ^ k,n ^ k) then ntensmx_rec A k' else 1).

Notation "A ^t k" := (ntensmx A k)
  (at level 39, left associativity, format "A ^t k").

(** Unfolding equations for the tensor power. *)
Lemma ntensmx0 : forall {m n} (A : 'M_(m,n)) , A ^t 0 = 1.

Lemma ntensmx1 : forall {m n} (A : 'M_(m,n)) , A ^t 1 = A.

Lemma ntensmx2 : forall {m n} (A : 'M_(m,n)) , A ^t 2 = A *t A.

Lemma ntensmxSS : forall {m n} (A : 'M_(m,n)) k, A ^t k.+2 = A *t A ^t k.+1.

(* Bundle of the successor unfolding equations for use as a rewrite rule. *)
Definition ntensmxS := (@ntensmx1, @ntensmx2, @ntensmxSS).

End MxTens.

(* Top-level copies of the tensor-product and tensor-power notations
   declared inside Section MxTens (same levels, associativity and formats).
   Presumably they are repeated here so they remain usable after
   End MxTens. *)
Notation "A *t B" :=
  (tensmx A B) (at level 40, left associativity, format "A *t B").

Notation "A ^t k" :=
  (ntensmx A k) (at level 39, left associativity, format "A ^t k").

Section MapMx.
(** NOTE(review): the Proof ... Qed. script of map_mxT is not present in
    this copy; it must be restored before this compiles. *)
Variables (aR rR : ringType).
Hypothesis f : {rmorphism aR -> rR}.
Local Notation "A ^f" := (map_mx f A) : ring_scope.

Variables m n p q: nat.
Implicit Type A : 'M[aR]_(m, n).
Implicit Type B : 'M[aR]_(p, q).

(** Mapping a ring morphism over the entries commutes with the tensor
    product. *)
Lemma map_mxT A B : (A *t B)^f = A^f *t B^f :> 'M_(m*p, n*q).

End MapMx.

Section Misc.

(** NOTE(review): the Proof ... Qed. scripts of the lemmas in this section
    are not present in this copy; they must be restored before it compiles. *)

(** Mixed-product property: over a commutative ring, the product of two
    tensor products is the tensor product of the products. *)
Lemma tensmx_mul (R : comRingType) m n p q r s
  (A : 'M[R]_(m,n)) (B : 'M[R]_(p,q)) (C : 'M[R]_(n, r)) (D : 'M[R]_(q, s)) :
  (A *t B) *m (C *t D) = (A *m C) *t (B *m D).

(* Todo : move to div ? *)
(** When both remainders are below d, q * d + m = q' * d + m' holds
    exactly when (q, m) = (q', m'). *)
Lemma eq_addl_mul q q' m m' d : m < d -> m' < d ->
  (q * d + m == q' * d + m')%N = ((q, m) == (q', m')).

(** Over a field, the tensor product of two invertible square matrices
    of nonzero size is invertible. *)
Lemma tensmx_unit (R : fieldType) m n (A : 'M[R]_m%N) (B : 'M[R]_n%N) :
  m != 0%N -> n != 0%N -> A \in unitmx -> B \in unitmx -> (A *t B) \in unitmx.

(** Tensoring on the right by a scalar matrix c%:M scales M by c, up to
    the casts m * 1 = m and n * 1 = n. *)
Lemma tens_mx_scalar : forall (R : comRingType)
  (m n : nat) (c : R) (M : 'M[R]_(m,n)),
  M *t c%:M = castmx (esym (muln1 _), esym (muln1 _)) (c *: M).

(** Two factorisations of M *t N into products of tensor products with
    identity matrices. *)
Lemma tensmx_decr : forall (R : comRingType) m n (M :'M[R]_m) (N : 'M[R]_n),
  M *t N = (M *t 1%:M) *m (1%:M *t N).

Lemma tensmx_decl : forall (R : comRingType) m n (M :'M[R]_m) (N : 'M[R]_n),
  M *t N = (1%:M *t N) *m (M *t 1%:M).

End Misc.