Library matrix

(* (c) Copyright Microsoft Corporation and Inria. All rights reserved. *)
Require Import ssreflect ssrbool ssrfun eqtype ssrnat seq div choice fintype.
Require Import finfun bigop prime binomial ssralg finset fingroup finalg.
Require Import perm zmodp.

Basic concrete linear algebra: definition of the type of matrices, and all basic matrix operations, including determinant, trace, and support for block decomposition. Matrices are represented by a row-major list of their coefficients, but this implementation is hidden by three levels of wrappers (Matrix/Finfun/Tuple), so the matrix type should be treated as abstract and handled using only the operations described below:
  'M[R]_(m, n), 'M_(m, n) == the type of m rows by n columns matrices with coefficients in R; the [R] is optional and is usually omitted.
  'M[R]_n, 'M_n == the type of n x n square matrices.
  'rV[R]_n, 'rV_n == the type of 1 x n row vectors.
  'cV[R]_n, 'cV_n == the type of n x 1 column vectors.
  \matrix_(i < m, j < n) Expr(i, j) == the m x n matrix with general coefficient Expr(i, j), with i : 'I_m and j : 'I_n. The < m bound can be omitted if it is equal to n, though usually both bounds are omitted as they can be inferred from the context.
  \row_(j < n) Expr(j), \col_(i < m) Expr(i) == the row / column vectors with general term Expr; the parentheses can be omitted along with the bound.
  \matrix_(i < m) RowExpr(i) == the m x n matrix with row i given by RowExpr(i) : 'rV_n.
  A i j == the coefficient of matrix A : 'M_(m, n) in column j of row i, where i : 'I_m, and j : 'I_n (via the coercion fun_of_matrix : matrix >-> Funclass).
  const_mx a == the constant matrix whose entries are all a (dimensions should be determined by context).
  map_mx f A == the pointwise image of A by f, i.e., the matrix Af congruent to A with Af i j = f (A i j) for all i and j.
  A^T == the matrix transpose of A.
  row i A == the i'th row of A (this is a row vector).
  col j A == the j'th column of A (a column vector).
  row' i A == A with the i'th row spliced out.
  col' j A == A with the j'th column spliced out.
  xrow i1 i2 A == A with rows i1 and i2 interchanged.
  xcol j1 j2 A == A with columns j1 and j2 interchanged.
  row_perm s A == A : 'M_(m, n) with rows permuted by s : 'S_m.
  col_perm s A == A : 'M_(m, n) with columns permuted by s : 'S_n.
  row_mx Al Ar == the row block matrix <Al Ar> obtained by concatenating two matrices Al and Ar of the same height.
  col_mx Au Ad == the column block matrix / Au \ (Au and Ad must have the same width).
                                          \ Ad /
  block_mx Aul Aur Adl Adr == the block matrix / Aul Aur \
                                               \ Adl Adr /
  [l|r]submx A == the left/right submatrices of a row block matrix A. Note that the type of A, 'M_(m, n1 + n2), indicates how A should be decomposed.
  [u|d]submx A == the up/down submatrices of a column block matrix A.
  [u|d][l|r]submx A == the upper left, etc. submatrices of a block matrix A.
  castmx eq_mn A == A : 'M_(m, n) cast to 'M_(m', n') using the equation pair eq_mn : (m = m') * (n = n'). This is the usual workaround for the syntactic limitations of dependent types in Coq, and can be used to introduce a block decomposition. It simplifies to A when eq_mn is the pair (erefl m, erefl n) (using rewrite /castmx /=).
  conform_mx B A == A if A and B have the same dimensions, else B.
  mxvec A == a row vector of width m * n holding all the entries of the m x n matrix A.
  mxvec_index i j == the index of A i j in mxvec A.
  vec_mx v == the inverse of mxvec, reshaping a vector of width m * n back into an m x n rectangular matrix.
In 'M[R]_(m, n), R can be any type, but 'M[R]_(m, n) inherits the eqType, choiceType, countType, finType, zmodType structures of R; 'M[R]_(m, n) also has a natural lmodType R structure when R has a ringType structure. Because the type of matrices specifies their dimension, only non-trivial square matrices (of type 'M[R]_n.+1) can inherit the ring structure of R; indeed they then have an algebra structure (lalgType R, or algType R if R is a comRingType, or even unitAlgType if R is a comUnitRingType). We thus provide separate syntax for the general matrix multiplication, and other operations for matrices over a ringType R:
  A *m B == the matrix product of A and B; the width of A must be equal to the height of B.
  a%:M == the scalar matrix with a's on the main diagonal; in particular 1%:M denotes the identity matrix, and is equal to 1%R when n is of the form n'.+1 (e.g., n >= 1).
  is_scalar_mx A <=> A is a scalar matrix (A = a%:M for some a).
  diag_mx d == the diagonal matrix whose main diagonal is d : 'rV_n.
  delta_mx i j == the matrix with a 1 in row i, column j and 0 elsewhere.
  pid_mx r == the partial identity matrix with 1s only on the r first coefficients of the main diagonal; the dimensions of pid_mx r are determined by the context, and pid_mx r can be rectangular.
  copid_mx r == the complement to 1%:M of pid_mx r: a square diagonal matrix with 1s on all but the first r coefficients on its main diagonal.
  perm_mx s == the n x n permutation matrix for s : 'S_n.
  tperm_mx i1 i2 == the permutation matrix that exchanges i1 i2 : 'I_n.
  is_perm_mx A == A is a permutation matrix.
  lift0_mx A == the 1 + n square matrix block_mx 1 0 0 A when A : 'M_n.
  \tr A == the trace of a square matrix A.
  \det A == the determinant of A, using the Leibniz formula.
  cofactor i j A == the i, j cofactor of A (the signed i, j minor of A).
  \adj A == the adjugate matrix of A (\adj A i j = cofactor j i A).
  A \in unitmx == A is invertible (R must be a comUnitRingType).
  invmx A == the inverse matrix of A if A \in unitmx, otherwise A.
The following operations provide a correspondence between linear functions and matrices:
  lin1_mx f == the m x n matrix that emulates via right product a (linear) function f : 'rV_m -> 'rV_n on ROW VECTORS.
  lin_mx f == the (m1 * n1) x (m2 * n2) matrix that emulates, via the right multiplication on the mxvec encodings, a linear function f : 'M_(m1, n1) -> 'M_(m2, n2).
  lin_mul_row u := lin1_mx (mulmx u \o vec_mx) (applies a row-encoded function to the row-vector u).
  mulmx A == partially applied matrix multiplication (mulmx A B is displayed as A *m B), with, for A : 'M_(m, n), a canonical {linear 'M_(n, p) -> 'M_(m, p)} structure.
  mulmxr A == self-simplifying right-hand matrix multiplication, i.e., mulmxr A B simplifies to B *m A, with, for A : 'M_(n, p), a canonical {linear 'M_(m, n) -> 'M_(m, p)} structure.
  lin_mulmx A := lin_mx (mulmx A).
  lin_mulmxr A := lin_mx (mulmxr A).
We also extend any finType structure of R to 'M[R]_(m, n), and define:
  {'GL_n[R]} == the finGroupType of units of 'M[R]_n.-1.+1.
  'GL_n[R] == the general linear group of all matrices in {'GL_n[R]}.
  'GL_n(p) == 'GL_n['F_p], the general linear group of a prime field.
  GLval u == the coercion of u : {'GL_n[R]} to a matrix.
In addition to the lemmas relevant to these definitions, this file also proves several classic results, including:
  • The determinant is a multilinear alternate form.
  • The Laplace determinant expansion formulas: expand_det_ [row|col].
  • The Cramer rule : mul_mx_adj & mul_adj_mx.
Finally, as an example of the use of block products, we program and prove the correctness of a classical linear algebra algorithm:
  cormenLUP A == the triangular decomposition (L, U, P) of a nontrivial square matrix A into a lower triangular matrix L with 1s on the main diagonal, an upper triangular matrix U, and a permutation matrix P, such that P * A = L * U.
This is an example only; we use a different, more precise algorithm to develop the theory of matrix ranks and row spaces in mxalgebra.v.

Set Implicit Arguments.

Import GroupScope.
Import GRing.Theory.
Open Local Scope ring_scope.

Reserved Notation "''M_' n" (at level 8, n at level 2, format "''M_' n").
Reserved Notation "''rV_' n" (at level 8, n at level 2, format "''rV_' n").
Reserved Notation "''cV_' n" (at level 8, n at level 2, format "''cV_' n").
Reserved Notation "''M_' ( n )" (at level 8, only parsing).
Reserved Notation "''M_' ( m , n )" (at level 8, format "''M_' ( m , n )").
Reserved Notation "''M[' R ]_ n" (at level 8, n at level 2, only parsing).
Reserved Notation "''rV[' R ]_ n" (at level 8, n at level 2, only parsing).
Reserved Notation "''cV[' R ]_ n" (at level 8, n at level 2, only parsing).
Reserved Notation "''M[' R ]_ ( n )" (at level 8, only parsing).
Reserved Notation "''M[' R ]_ ( m , n )" (at level 8, only parsing).

Reserved Notation "\matrix_ i E"
  (at level 36, E at level 36, i at level 2,
   format "\matrix_ i E").
Reserved Notation "\matrix_ ( i < n ) E"
  (at level 36, E at level 36, i, n at level 50, only parsing).
Reserved Notation "\matrix_ ( i , j ) E"
  (at level 36, E at level 36, i, j at level 50,
   format "\matrix_ ( i , j ) E").
Reserved Notation "\matrix[ k ]_ ( i , j ) E"
  (at level 36, E at level 36, i, j at level 50,
   format "\matrix[ k ]_ ( i , j ) E").
Reserved Notation "\matrix_ ( i < m , j < n ) E"
  (at level 36, E at level 36, i, m, j, n at level 50, only parsing).
Reserved Notation "\matrix_ ( i , j < n ) E"
  (at level 36, E at level 36, i, j, n at level 50, only parsing).
Reserved Notation "\row_ j E"
  (at level 36, E at level 36, j at level 2,
   format "\row_ j E").
Reserved Notation "\row_ ( j < n ) E"
  (at level 36, E at level 36, j, n at level 50, only parsing).
Reserved Notation "\col_ j E"
  (at level 36, E at level 36, j at level 2,
   format "\col_ j E").
Reserved Notation "\col_ ( j < n ) E"
  (at level 36, E at level 36, j, n at level 50, only parsing).

Reserved Notation "x %:M" (at level 8, format "x %:M").
Reserved Notation "A *m B" (at level 40, left associativity, format "A *m B").
Reserved Notation "A ^T" (at level 8, format "A ^T").
Reserved Notation "\tr A" (at level 10, A at level 8, format "\tr A").
Reserved Notation "\det A" (at level 10, A at level 8, format "\det A").
Reserved Notation "\adj A" (at level 10, A at level 8, format "\adj A").

Notation Local simp := (Monoid.Theory.simpm, oppr0).

Type Definition

Section MatrixDef.

Variable R : Type.
Variables m n : nat.

Basic linear algebra (matrices). We use dependent types (ordinals) for the indices so that ranges are mostly inferred automatically.

Inductive matrix : predArgType := Matrix of {ffun 'I_m × 'I_n → R}.

Definition mx_val A := let: Matrix g := A in g.

Canonical matrix_subType := Eval hnf in [newType for mx_val].

Fact matrix_key : unit.
Definition matrix_of_fun_def F := Matrix [ffun ij ⇒ F ij.1 ij.2].
Definition matrix_of_fun k := locked_with k matrix_of_fun_def.
Canonical matrix_unlockable k := [unlockable fun matrix_of_fun k].

Definition fun_of_matrix A (i : 'I_m) (j : 'I_n) := mx_val A (i, j).

Coercion fun_of_matrix : matrix >-> Funclass.

Lemma mxE k F : matrix_of_fun k F =2 F.

Lemma matrixP (A B : matrix) : A =2 B ↔ A = B.

End MatrixDef.


Notation "''M[' R ]_ ( m , n )" := (matrix R m n) (only parsing): type_scope.
Notation "''rV[' R ]_ n" := 'M[R]_(1, n) (only parsing) : type_scope.
Notation "''cV[' R ]_ n" := 'M[R]_(n, 1) (only parsing) : type_scope.
Notation "''M[' R ]_ n" := 'M[R]_(n, n) (only parsing) : type_scope.
Notation "''M[' R ]_ ( n )" := 'M[R]_n (only parsing) : type_scope.
Notation "''M_' ( m , n )" := 'M[_]_(m, n) : type_scope.
Notation "''rV_' n" := 'M_(1, n) : type_scope.
Notation "''cV_' n" := 'M_(n, 1) : type_scope.
Notation "''M_' n" := 'M_(n, n) : type_scope.
Notation "''M_' ( n )" := 'M_n (only parsing) : type_scope.

Notation "\matrix[ k ]_ ( i , j ) E" := (matrix_of_fun k (fun i j ⇒ E))
  (at level 36, E at level 36, i, j at level 50): ring_scope.

Notation "\matrix_ ( i < m , j < n ) E" :=
  (@matrix_of_fun _ m n matrix_key (fun i j ⇒ E)) (only parsing) : ring_scope.

Notation "\matrix_ ( i , j < n ) E" :=
  (\matrix_(i < n, j < n) E) (only parsing) : ring_scope.

Notation "\matrix_ ( i , j ) E" := (\matrix_(i < _, j < _) E) : ring_scope.

Notation "\matrix_ ( i < m ) E" :=
  (\matrix_(i < m, j < _) @fun_of_matrix _ 1 _ E 0 j)
  (only parsing) : ring_scope.
Notation "\matrix_ i E" := (\matrix_(i < _) E) : ring_scope.

Notation "\col_ ( i < n ) E" := (@matrix_of_fun _ n 1 matrix_key (fun i _ ⇒ E))
  (only parsing) : ring_scope.
Notation "\col_ i E" := (\col_(i < _) E) : ring_scope.

Notation "\row_ ( j < n ) E" := (@matrix_of_fun _ 1 n matrix_key (fun _ j ⇒ E))
  (only parsing) : ring_scope.
Notation "\row_ j E" := (\row_(j < _) E) : ring_scope.
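
As a hedged illustration of the notations above, the following sketch builds a matrix from a coefficient function and reads a coefficient back with mxE. It is not part of the library: the names F, ex_mx and ex_mxE are hypothetical, and the block assumes it is placed after these notations with ring_scope open, as in this file.

```coq
Section MatrixNotationExample.

(* Hypothetical coefficient function; any type R will do here. *)
Variables (R : Type) (m n : nat) (F : 'I_m -> 'I_n -> R).

(* The m x n matrix whose (i, j) coefficient is F i j. *)
Definition ex_mx : 'M[R]_(m, n) := \matrix_(i, j) F i j.

(* mxE exposes the coefficient of a matrix built with \matrix_. *)
Lemma ex_mxE i j : ex_mx i j = F i j.
Proof. by rewrite /ex_mx mxE. Qed.

End MatrixNotationExample.
```

The application ex_mx i j goes through the coercion fun_of_matrix, so no explicit accessor is needed.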

Definition matrix_eqMixin (R : eqType) m n :=
  Eval hnf in [eqMixin of 'M[R]_(m, n) by <:].
Canonical matrix_eqType (R : eqType) m n:=
  Eval hnf in EqType 'M[R]_(m, n) (matrix_eqMixin R m n).
Definition matrix_choiceMixin (R : choiceType) m n :=
  [choiceMixin of 'M[R]_(m, n) by <:].
Canonical matrix_choiceType (R : choiceType) m n :=
  Eval hnf in ChoiceType 'M[R]_(m, n) (matrix_choiceMixin R m n).
Definition matrix_countMixin (R : countType) m n :=
  [countMixin of 'M[R]_(m, n) by <:].
Canonical matrix_countType (R : countType) m n :=
  Eval hnf in CountType 'M[R]_(m, n) (matrix_countMixin R m n).
Canonical matrix_subCountType (R : countType) m n :=
  Eval hnf in [subCountType of 'M[R]_(m, n)].
Definition matrix_finMixin (R : finType) m n :=
  [finMixin of 'M[R]_(m, n) by <:].
Canonical matrix_finType (R : finType) m n :=
  Eval hnf in FinType 'M[R]_(m, n) (matrix_finMixin R m n).
Canonical matrix_subFinType (R : finType) m n :=
  Eval hnf in [subFinType of 'M[R]_(m, n)].

Lemma card_matrix (F : finType) m n : (#|{: 'M[F]_(m, n)}| = #|F| ^ (m × n))%N.

Matrix structural operations (transpose, permutation, blocks)

Section MatrixStructural.

Variable R : Type.

Constant matrix
Fact const_mx_key : unit.
Definition const_mx m n a : 'M[R]_(m, n) := \matrix[const_mx_key]_(i, j) a.
Implicit Arguments const_mx [[m] [n]].

Section FixedDim.
Definitions and properties for which we can work with fixed dimensions.

Variables m n : nat.
Implicit Type A : 'M[R]_(m, n).

Reshape a matrix, to accommodate the block functions for instance.
Definition castmx m' n' (eq_mn : (m = m') × (n = n')) A : 'M_(m', n') :=
  let: erefl in _ = m' := eq_mn.1 return 'M_(m', n') in
  let: erefl in _ = n' := eq_mn.2 return 'M_(m, n') in A.

Definition conform_mx m' n' B A :=
  match m =P m', n =P n' with
  | ReflectT eq_m, ReflectT eq_n ⇒ castmx (eq_m, eq_n) A
  | _, _ ⇒ B
  end.

Transpose a matrix
Fact trmx_key : unit.
Definition trmx A := \matrix[trmx_key]_(i, j) A j i.

Permute a matrix vertically (rows) or horizontally (columns)
Fact row_perm_key : unit.
Definition row_perm (s : 'S_m) A := \matrix[row_perm_key]_(i, j) A (s i) j.
Fact col_perm_key : unit.
Definition col_perm (s : 'S_n) A := \matrix[col_perm_key]_(i, j) A i (s j).

Exchange two rows/columns of a matrix
Definition xrow i1 i2 := row_perm (tperm i1 i2).
Definition xcol j1 j2 := col_perm (tperm j1 j2).

Row/Column sub matrices of a matrix
Definition row i0 A := \row_j A i0 j.
Definition col j0 A := \col_i A i j0.

Removing a row/column from a matrix
Definition row' i0 A := \matrix_(i, j) A (lift i0 i) j.
Definition col' j0 A := \matrix_(i, j) A i (lift j0 j).

Lemma castmx_const m' n' (eq_mn : (m = m') × (n = n')) a :
  castmx eq_mn (const_mx a) = const_mx a.

Lemma trmx_const a : trmx (const_mx a) = const_mx a.

Lemma row_perm_const s a : row_perm s (const_mx a) = const_mx a.

Lemma col_perm_const s a : col_perm s (const_mx a) = const_mx a.

Lemma xrow_const i1 i2 a : xrow i1 i2 (const_mx a) = const_mx a.

Lemma xcol_const j1 j2 a : xcol j1 j2 (const_mx a) = const_mx a.

Lemma rowP (u v : 'rV[R]_n) : u 0 =1 v 0 ↔ u = v.

Lemma rowK u_ i0 : row i0 (\matrix_i u_ i) = u_ i0.

Lemma row_matrixP A B : (∀ i, row i A = row i B) ↔ A = B.

Lemma colP (u v : 'cV[R]_m) : u^~ 0 =1 v^~ 0 ↔ u = v.

Lemma row_const i0 a : row i0 (const_mx a) = const_mx a.

Lemma col_const j0 a : col j0 (const_mx a) = const_mx a.

Lemma row'_const i0 a : row' i0 (const_mx a) = const_mx a.

Lemma col'_const j0 a : col' j0 (const_mx a) = const_mx a.

Lemma col_perm1 A : col_perm 1 A = A.

Lemma row_perm1 A : row_perm 1 A = A.

Lemma col_permM s t A : col_perm (s × t) A = col_perm s (col_perm t A).

Lemma row_permM s t A : row_perm (s × t) A = row_perm s (row_perm t A).

Lemma col_row_permC s t A :
  col_perm s (row_perm t A) = row_perm t (col_perm s A).

End FixedDim.

Local Notation "A ^T" := (trmx A) : ring_scope.

Lemma castmx_id m n erefl_mn (A : 'M_(m, n)) : castmx erefl_mn A = A.

Lemma castmx_comp m1 n1 m2 n2 m3 n3 (eq_m1 : m1 = m2) (eq_n1 : n1 = n2)
                                    (eq_m2 : m2 = m3) (eq_n2 : n2 = n3) A :
  castmx (eq_m2, eq_n2) (castmx (eq_m1, eq_n1) A)
    = castmx (etrans eq_m1 eq_m2, etrans eq_n1 eq_n2) A.

Lemma castmxK m1 n1 m2 n2 (eq_m : m1 = m2) (eq_n : n1 = n2) :
  cancel (castmx (eq_m, eq_n)) (castmx (esym eq_m, esym eq_n)).

Lemma castmxKV m1 n1 m2 n2 (eq_m : m1 = m2) (eq_n : n1 = n2) :
  cancel (castmx (esym eq_m, esym eq_n)) (castmx (eq_m, eq_n)).

This can be used to reverse an equation that involves a cast.
Lemma castmx_sym m1 n1 m2 n2 (eq_m : m1 = m2) (eq_n : n1 = n2) A1 A2 :
  A1 = castmx (eq_m, eq_n) A2 → A2 = castmx (esym eq_m, esym eq_n) A1.

Lemma castmxE m1 n1 m2 n2 (eq_mn : (m1 = m2) × (n1 = n2)) A i j :
  castmx eq_mn A i j =
     A (cast_ord (esym eq_mn.1) i) (cast_ord (esym eq_mn.2) j).
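
To make the role of castmx concrete, here is a minimal sketch of a cast that removes a trailing "+ 0" from a width. The names drop_add0 and ex_cast_refl are hypothetical. The block assumes it sits inside Section MatrixStructural, so that R : Type is in scope, and that addn0 is the ssrnat lemma stating n + 0 = n.

```coq
(* Hypothetical example: the widths n + 0 and n are provably equal but not
   syntactically identical, so a cast is needed to change the type. *)
Definition drop_add0 m n (A : 'M[R]_(m, n + 0)) : 'M[R]_(m, n) :=
  castmx (erefl m, addn0 n) A.

(* A cast along reflexive equations is the identity. *)
Lemma ex_cast_refl m n (A : 'M[R]_(m, n)) : castmx (erefl m, erefl n) A = A.
Proof. by rewrite castmx_id. Qed.
```

Coefficients of a cast matrix can then be reached through castmxE, which reindexes with cast_ord.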

Lemma conform_mx_id m n (B A : 'M_(m, n)) : conform_mx B A = A.

Lemma nonconform_mx m m' n n' (B : 'M_(m', n')) (A : 'M_(m, n)) :
  (m != m') || (n != n') → conform_mx B A = B.

Lemma conform_castmx m1 n1 m2 n2 m3 n3
                     (e_mn : (m2 = m3) × (n2 = n3)) (B : 'M_(m1, n1)) A :
  conform_mx B (castmx e_mn A) = conform_mx B A.

Lemma trmxK m n : cancel (@trmx m n) (@trmx n m).

Lemma trmx_inj m n : injective (@trmx m n).

Lemma trmx_cast m1 n1 m2 n2 (eq_mn : (m1 = m2) × (n1 = n2)) A :
  (castmx eq_mn A)^T = castmx (eq_mn.2, eq_mn.1) A^T.

Lemma tr_row_perm m n s (A : 'M_(m, n)) : (row_perm s A)^T = col_perm s A^T.

Lemma tr_col_perm m n s (A : 'M_(m, n)) : (col_perm s A)^T = row_perm s A^T.

Lemma tr_xrow m n i1 i2 (A : 'M_(m, n)) : (xrow i1 i2 A)^T = xcol i1 i2 A^T.

Lemma tr_xcol m n j1 j2 (A : 'M_(m, n)) : (xcol j1 j2 A)^T = xrow j1 j2 A^T.

Lemma row_id n i (V : 'rV_n) : row i V = V.

Lemma col_id n j (V : 'cV_n) : col j V = V.

Lemma row_eq m1 m2 n i1 i2 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  row i1 A1 = row i2 A2 → A1 i1 =1 A2 i2.

Lemma col_eq m n1 n2 j1 j2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  col j1 A1 = col j2 A2 → A1^~ j1 =1 A2^~ j2.

Lemma row'_eq m n i0 (A B : 'M_(m, n)) :
  row' i0 A = row' i0 B → {in predC1 i0, A =2 B}.

Lemma col'_eq m n j0 (A B : 'M_(m, n)) :
  col' j0 A = col' j0 B → ∀ i, {in predC1 j0, A i =1 B i}.

Lemma tr_row m n i0 (A : 'M_(m, n)) : (row i0 A)^T = col i0 A^T.

Lemma tr_row' m n i0 (A : 'M_(m, n)) : (row' i0 A)^T = col' i0 A^T.

Lemma tr_col m n j0 (A : 'M_(m, n)) : (col j0 A)^T = row j0 A^T.

Lemma tr_col' m n j0 (A : 'M_(m, n)) : (col' j0 A)^T = row' j0 A^T.

Ltac split_mxE := apply/matrixP⇒ i j; do ![rewrite mxE | case: split ⇒ ?].

Section CutPaste.

Variables m m1 m2 n n1 n2 : nat.

Concatenating two matrices, in either direction.

Fact row_mx_key : unit.
Definition row_mx (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) : 'M[R]_(m, n1 + n2) :=
  \matrix[row_mx_key]_(i, j)
     match split j with inl j1 ⇒ A1 i j1 | inr j2 ⇒ A2 i j2 end.

Fact col_mx_key : unit.
Definition col_mx (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) : 'M[R]_(m1 + m2, n) :=
  \matrix[col_mx_key]_(i, j)
     match split i with inl i1 ⇒ A1 i1 j | inr i2 ⇒ A2 i2 j end.

Left/Right | Up/Down submatrices of a rows | columns matrix. The shape of the (dependent) width parameters of the type of A determines which submatrix is selected.

Fact lsubmx_key : unit.
Definition lsubmx (A : 'M[R]_(m, n1 + n2)) :=
  \matrix[lsubmx_key]_(i, j) A i (lshift n2 j).

Fact rsubmx_key : unit.
Definition rsubmx (A : 'M[R]_(m, n1 + n2)) :=
  \matrix[rsubmx_key]_(i, j) A i (rshift n1 j).

Fact usubmx_key : unit.
Definition usubmx (A : 'M[R]_(m1 + m2, n)) :=
  \matrix[usubmx_key]_(i, j) A (lshift m2 i) j.

Fact dsubmx_key : unit.
Definition dsubmx (A : 'M[R]_(m1 + m2, n)) :=
  \matrix[dsubmx_key]_(i, j) A (rshift m1 i) j.

Lemma row_mxEl A1 A2 i j : row_mx A1 A2 i (lshift n2 j) = A1 i j.

Lemma row_mxKl A1 A2 : lsubmx (row_mx A1 A2) = A1.

Lemma row_mxEr A1 A2 i j : row_mx A1 A2 i (rshift n1 j) = A2 i j.

Lemma row_mxKr A1 A2 : rsubmx (row_mx A1 A2) = A2.

Lemma hsubmxK A : row_mx (lsubmx A) (rsubmx A) = A.

Lemma col_mxEu A1 A2 i j : col_mx A1 A2 (lshift m2 i) j = A1 i j.

Lemma col_mxKu A1 A2 : usubmx (col_mx A1 A2) = A1.

Lemma col_mxEd A1 A2 i j : col_mx A1 A2 (rshift m1 i) j = A2 i j.

Lemma col_mxKd A1 A2 : dsubmx (col_mx A1 A2) = A2.

Lemma eq_row_mx A1 A2 B1 B2 : row_mx A1 A2 = row_mx B1 B2 → A1 = B1 ∧ A2 = B2.

Lemma eq_col_mx A1 A2 B1 B2 : col_mx A1 A2 = col_mx B1 B2 → A1 = B1 ∧ A2 = B2.

Lemma row_mx_const a : row_mx (const_mx a) (const_mx a) = const_mx a.

Lemma col_mx_const a : col_mx (const_mx a) (const_mx a) = const_mx a.

End CutPaste.

Lemma trmx_lsub m n1 n2 (A : 'M_(m, n1 + n2)) : (lsubmx A)^T = usubmx A^T.

Lemma trmx_rsub m n1 n2 (A : 'M_(m, n1 + n2)) : (rsubmx A)^T = dsubmx A^T.

Lemma tr_row_mx m n1 n2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  (row_mx A1 A2)^T = col_mx A1^T A2^T.

Lemma tr_col_mx m1 m2 n (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  (col_mx A1 A2)^T = row_mx A1^T A2^T.

Lemma trmx_usub m1 m2 n (A : 'M_(m1 + m2, n)) : (usubmx A)^T = lsubmx A^T.

Lemma trmx_dsub m1 m2 n (A : 'M_(m1 + m2, n)) : (dsubmx A)^T = rsubmx A^T.

Lemma vsubmxK m1 m2 n (A : 'M_(m1 + m2, n)) : col_mx (usubmx A) (dsubmx A) = A.

Lemma cast_row_mx m m' n1 n2 (eq_m : m = m') A1 A2 :
  castmx (eq_m, erefl _) (row_mx A1 A2)
    = row_mx (castmx (eq_m, erefl n1) A1) (castmx (eq_m, erefl n2) A2).

Lemma cast_col_mx m1 m2 n n' (eq_n : n = n') A1 A2 :
  castmx (erefl _, eq_n) (col_mx A1 A2)
    = col_mx (castmx (erefl m1, eq_n) A1) (castmx (erefl m2, eq_n) A2).

This lemma has Prenex Implicits to help right-to-left rewriting with castmx_sym.
Lemma row_mxA m n1 n2 n3 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) (A3 : 'M_(m, n3)) :
  let cast := (erefl m, esym (addnA n1 n2 n3)) in
  row_mx A1 (row_mx A2 A3) = castmx cast (row_mx (row_mx A1 A2) A3).
Definition row_mxAx := row_mxA. (* bypass Prenex Implicits. *)

This lemma has Prenex Implicits to help right-to-left rewriting with castmx_sym.
Lemma col_mxA m1 m2 m3 n (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) (A3 : 'M_(m3, n)) :
  let cast := (esym (addnA m1 m2 m3), erefl n) in
  col_mx A1 (col_mx A2 A3) = castmx cast (col_mx (col_mx A1 A2) A3).
Definition col_mxAx := col_mxA. (* bypass Prenex Implicits. *)

Lemma row_row_mx m n1 n2 i0 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  row i0 (row_mx A1 A2) = row_mx (row i0 A1) (row i0 A2).

Lemma col_col_mx m1 m2 n j0 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  col j0 (col_mx A1 A2) = col_mx (col j0 A1) (col j0 A2).

Lemma row'_row_mx m n1 n2 i0 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  row' i0 (row_mx A1 A2) = row_mx (row' i0 A1) (row' i0 A2).

Lemma col'_col_mx m1 m2 n j0 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  col' j0 (col_mx A1 A2) = col_mx (col' j0 A1) (col' j0 A2).

Lemma colKl m n1 n2 j1 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  col (lshift n2 j1) (row_mx A1 A2) = col j1 A1.

Lemma colKr m n1 n2 j2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  col (rshift n1 j2) (row_mx A1 A2) = col j2 A2.

Lemma rowKu m1 m2 n i1 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  row (lshift m2 i1) (col_mx A1 A2) = row i1 A1.

Lemma rowKd m1 m2 n i2 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  row (rshift m1 i2) (col_mx A1 A2) = row i2 A2.

Lemma col'Kl m n1 n2 j1 (A1 : 'M_(m, n1.+1)) (A2 : 'M_(m, n2)) :
  col' (lshift n2 j1) (row_mx A1 A2) = row_mx (col' j1 A1) A2.

Lemma row'Ku m1 m2 n i1 (A1 : 'M_(m1.+1, n)) (A2 : 'M_(m2, n)) :
  row' (lshift m2 i1) (@col_mx m1.+1 m2 n A1 A2) = col_mx (row' i1 A1) A2.

Lemma mx'_cast m n : 'I_n → (m + n.-1)%N = (m + n).-1.

Lemma col'Kr m n1 n2 j2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  col' (rshift n1 j2) (@row_mx m n1 n2 A1 A2)
    = castmx (erefl m, mx'_cast n1 j2) (row_mx A1 (col' j2 A2)).

Lemma row'Kd m1 m2 n i2 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  row' (rshift m1 i2) (col_mx A1 A2)
    = castmx (mx'_cast m1 i2, erefl n) (col_mx A1 (row' i2 A2)).

Section Block.

Variables m1 m2 n1 n2 : nat.

Building a block matrix from 4 matrices: the up left, up right, down left and down right components.

Definition block_mx Aul Aur Adl Adr : 'M_(m1 + m2, n1 + n2) :=
  col_mx (row_mx Aul Aur) (row_mx Adl Adr).

Lemma eq_block_mx Aul Aur Adl Adr Bul Bur Bdl Bdr :
 block_mx Aul Aur Adl Adr = block_mx Bul Bur Bdl Bdr →
  [/\ Aul = Bul, Aur = Bur, Adl = Bdl & Adr = Bdr].

Lemma block_mx_const a :
  block_mx (const_mx a) (const_mx a) (const_mx a) (const_mx a) = const_mx a.

Section CutBlock.

Variable A : matrix R (m1 + m2) (n1 + n2).

Definition ulsubmx := lsubmx (usubmx A).
Definition ursubmx := rsubmx (usubmx A).
Definition dlsubmx := lsubmx (dsubmx A).
Definition drsubmx := rsubmx (dsubmx A).

Lemma submxK : block_mx ulsubmx ursubmx dlsubmx drsubmx = A.

End CutBlock.

Section CatBlock.

Variables (Aul : 'M[R]_(m1, n1)) (Aur : 'M[R]_(m1, n2)).
Variables (Adl : 'M[R]_(m2, n1)) (Adr : 'M[R]_(m2, n2)).

Let A := block_mx Aul Aur Adl Adr.

Lemma block_mxEul i j : A (lshift m2 i) (lshift n2 j) = Aul i j.
Lemma block_mxKul : ulsubmx A = Aul.

Lemma block_mxEur i j : A (lshift m2 i) (rshift n1 j) = Aur i j.
Lemma block_mxKur : ursubmx A = Aur.

Lemma block_mxEdl i j : A (rshift m1 i) (lshift n2 j) = Adl i j.
Lemma block_mxKdl : dlsubmx A = Adl.

Lemma block_mxEdr i j : A (rshift m1 i) (rshift n1 j) = Adr i j.
Lemma block_mxKdr : drsubmx A = Adr.

Lemma block_mxEv : A = col_mx (row_mx Aul Aur) (row_mx Adl Adr).

End CatBlock.

End Block.
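
The block operations compose with the row and column ones. As a small sketch (the lemma name ex_usub_block is hypothetical, and the block assumes it sits inside Section MatrixStructural, where R : Type is in scope), the upper part of a block matrix is the row block of its two upper components:

```coq
(* Hypothetical example: block_mx is col_mx of two row_mx, so unfolding it
   lets col_mxKu select the upper row block. *)
Lemma ex_usub_block m1 m2 n1 n2
    (Aul : 'M[R]_(m1, n1)) (Aur : 'M[R]_(m1, n2))
    (Adl : 'M[R]_(m2, n1)) (Adr : 'M[R]_(m2, n2)) :
  usubmx (block_mx Aul Aur Adl Adr) = row_mx Aul Aur.
Proof. by rewrite /block_mx col_mxKu. Qed.
```

Note that the type 'M_(m1 + m2, n1 + n2) of the block matrix is what tells usubmx where to cut.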

Section TrCutBlock.

Variables m1 m2 n1 n2 : nat.
Variable A : 'M[R]_(m1 + m2, n1 + n2).

Lemma trmx_ulsub : (ulsubmx A)^T = ulsubmx A^T.

Lemma trmx_ursub : (ursubmx A)^T = dlsubmx A^T.

Lemma trmx_dlsub : (dlsubmx A)^T = ursubmx A^T.

Lemma trmx_drsub : (drsubmx A)^T = drsubmx A^T.

End TrCutBlock.

Section TrBlock.
Variables m1 m2 n1 n2 : nat.
Variables (Aul : 'M[R]_(m1, n1)) (Aur : 'M[R]_(m1, n2)).
Variables (Adl : 'M[R]_(m2, n1)) (Adr : 'M[R]_(m2, n2)).

Lemma tr_block_mx :
 (block_mx Aul Aur Adl Adr)^T = block_mx Aul^T Adl^T Aur^T Adr^T.

Lemma block_mxEh :
  block_mx Aul Aur Adl Adr = row_mx (col_mx Aul Adl) (col_mx Aur Adr).
End TrBlock.

This lemma has Prenex Implicits to help right-to-left rewriting with castmx_sym.
Lemma block_mxA m1 m2 m3 n1 n2 n3
   (A11 : 'M_(m1, n1)) (A12 : 'M_(m1, n2)) (A13 : 'M_(m1, n3))
   (A21 : 'M_(m2, n1)) (A22 : 'M_(m2, n2)) (A23 : 'M_(m2, n3))
   (A31 : 'M_(m3, n1)) (A32 : 'M_(m3, n2)) (A33 : 'M_(m3, n3)) :
  let cast := (esym (addnA m1 m2 m3), esym (addnA n1 n2 n3)) in
  let row1 := row_mx A12 A13 in let col1 := col_mx A21 A31 in
  let row3 := row_mx A31 A32 in let col3 := col_mx A13 A23 in
  block_mx A11 row1 col1 (block_mx A22 A23 A32 A33)
    = castmx cast (block_mx (block_mx A11 A12 A21 A22) col3 row3 A33).
Definition block_mxAx := block_mxA. (* Bypass Prenex Implicits *)

Bijections mxvec : 'M_(m, n) <----> 'rV_(m * n) : vec_mx
Section VecMatrix.

Variables m n : nat.

Lemma mxvec_cast : #|{:'I_m × 'I_n}| = (m × n)%N.

Definition mxvec_index (i : 'I_m) (j : 'I_n) :=
  cast_ord mxvec_cast (enum_rank (i, j)).

CoInductive is_mxvec_index : 'I_(m × n) → Type :=
  IsMxvecIndex i j : is_mxvec_index (mxvec_index i j).

Lemma mxvec_indexP k : is_mxvec_index k.

Coercion pair_of_mxvec_index k (i_k : is_mxvec_index k) :=
  let: IsMxvecIndex i j := i_k in (i, j).

Definition mxvec (A : 'M[R]_(m, n)) :=
  castmx (erefl _, mxvec_cast) (\row_k A (enum_val k).1 (enum_val k).2).

Fact vec_mx_key : unit.
Definition vec_mx (u : 'rV[R]_(m × n)) :=
  \matrix[vec_mx_key]_(i, j) u 0 (mxvec_index i j).

Lemma mxvecE A i j : mxvec A 0 (mxvec_index i j) = A i j.

Lemma mxvecK : cancel mxvec vec_mx.

Lemma vec_mxK : cancel vec_mx mxvec.

Lemma curry_mxvec_bij : {on 'I_(m × n), bijective (prod_curry mxvec_index)}.

End VecMatrix.
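
Since mxvec and vec_mx are mutually inverse, each is injective. The following sketch derives this for mxvec from mxvecK. The lemma name ex_mxvec_inj is hypothetical, the block assumes it sits inside Section MatrixStructural after End VecMatrix, and can_inj is assumed to be the ssrfun lemma turning a cancellation into injectivity.

```coq
(* Hypothetical example: a function with a left inverse is injective. *)
Lemma ex_mxvec_inj m n : injective (@mxvec m n).
Proof. exact: (can_inj (@mxvecK m n)). Qed.
```

The same argument with vec_mxK would give injectivity of vec_mx.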

End MatrixStructural.

Implicit Arguments const_mx [R m n].
Implicit Arguments row_mxA [R m n1 n2 n3 A1 A2 A3].
Implicit Arguments col_mxA [R m1 m2 m3 n A1 A2 A3].
Implicit Arguments block_mxA
  [R m1 m2 m3 n1 n2 n3 A11 A12 A13 A21 A22 A23 A31 A32 A33].

Notation "A ^T" := (trmx A) : ring_scope.

Matrix parametricity.
Section MapMatrix.

Variables (aT rT : Type) (f : aT → rT).

Fact map_mx_key : unit.
Definition map_mx m n (A : 'M_(m, n)) := \matrix[map_mx_key]_(i, j) f (A i j).

Notation "A ^f" := (map_mx A) : ring_scope.

Section OneMatrix.

Variables (m n : nat) (A : 'M[aT]_(m, n)).

Lemma map_trmx : A^f^T = A^T^f.

Lemma map_const_mx a : (const_mx a)^f = const_mx (f a) :> 'M_(m, n).

Lemma map_row i : (row i A)^f = row i A^f.

Lemma map_col j : (col j A)^f = col j A^f.

Lemma map_row' i0 : (row' i0 A)^f = row' i0 A^f.

Lemma map_col' j0 : (col' j0 A)^f = col' j0 A^f.

Lemma map_row_perm s : (row_perm s A)^f = row_perm s A^f.

Lemma map_col_perm s : (col_perm s A)^f = col_perm s A^f.

Lemma map_xrow i1 i2 : (xrow i1 i2 A)^f = xrow i1 i2 A^f.

Lemma map_xcol j1 j2 : (xcol j1 j2 A)^f = xcol j1 j2 A^f.

Lemma map_castmx m' n' c : (castmx c A)^f = castmx c A^f :> 'M_(m', n').

Lemma map_conform_mx m' n' (B : 'M_(m', n')) :
  (conform_mx B A)^f = conform_mx B^f A^f.

Lemma map_mxvec : (mxvec A)^f = mxvec A^f.

Lemma map_vec_mx (v : 'rV_(m × n)) : (vec_mx v)^f = vec_mx v^f.

End OneMatrix.

Section Block.

Variables m1 m2 n1 n2 : nat.
Variables (Aul : 'M[aT]_(m1, n1)) (Aur : 'M[aT]_(m1, n2)).
Variables (Adl : 'M[aT]_(m2, n1)) (Adr : 'M[aT]_(m2, n2)).
Variables (Bh : 'M[aT]_(m1, n1 + n2)) (Bv : 'M[aT]_(m1 + m2, n1)).
Variable B : 'M[aT]_(m1 + m2, n1 + n2).

Lemma map_row_mx : (row_mx Aul Aur)^f = row_mx Aul^f Aur^f.

Lemma map_col_mx : (col_mx Aul Adl)^f = col_mx Aul^f Adl^f.

Lemma map_block_mx :
  (block_mx Aul Aur Adl Adr)^f = block_mx Aul^f Aur^f Adl^f Adr^f.

Lemma map_lsubmx : (lsubmx Bh)^f = lsubmx Bh^f.

Lemma map_rsubmx : (rsubmx Bh)^f = rsubmx Bh^f.

Lemma map_usubmx : (usubmx Bv)^f = usubmx Bv^f.

Lemma map_dsubmx : (dsubmx Bv)^f = dsubmx Bv^f.

Lemma map_ulsubmx : (ulsubmx B)^f = ulsubmx B^f.

Lemma map_ursubmx : (ursubmx B)^f = ursubmx B^f.

Lemma map_dlsubmx : (dlsubmx B)^f = dlsubmx B^f.

Lemma map_drsubmx : (drsubmx B)^f = drsubmx B^f.

End Block.

End MapMatrix.
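
All the map_mx lemmas above ultimately rest on the coefficientwise description of map_mx. A minimal sketch of that description follows. The lemma name ex_map_mxE is hypothetical, and the block assumes it is placed after End MapMatrix, where map_mx takes f as its first explicit argument.

```coq
(* Hypothetical example: the (i, j) coefficient of the image matrix is the
   image of the (i, j) coefficient. *)
Lemma ex_map_mxE (aT rT : Type) (f : aT -> rT) m n (A : 'M[aT]_(m, n)) i j :
  map_mx f A i j = f (A i j).
Proof. by rewrite mxE. Qed.
```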

Matrix Zmodule (additive) structure

Section MatrixZmodule.

Variable V : zmodType.

Section FixedDim.

Variables m n : nat.
Implicit Types A B : 'M[V]_(m, n).

Fact oppmx_key : unit.
Fact addmx_key : unit.
Definition oppmx A := \matrix[oppmx_key]_(i, j) (- A i j).
Definition addmx A B := \matrix[addmx_key]_(i, j) (A i j + B i j).
In principle, diag_mx and scalar_mx could be defined here, but since they only make sense with the graded ring operations, we defer them to the next section.

Lemma addmxA : associative addmx.

Lemma addmxC : commutative addmx.

Lemma add0mx : left_id (const_mx 0) addmx.

Lemma addNmx : left_inverse (const_mx 0) oppmx addmx.

Definition matrix_zmodMixin := ZmodMixin addmxA addmxC add0mx addNmx.

Canonical matrix_zmodType := Eval hnf in ZmodType 'M[V]_(m, n) matrix_zmodMixin.

Lemma mulmxnE A d i j : (A *+ d) i j = A i j *+ d.

Lemma summxE I r (P : pred I) (E : I → 'M_(m, n)) i j :
  (\sum_(k <- r | P k) E k) i j = \sum_(k <- r | P k) E k i j.

Lemma const_mx_is_additive : additive const_mx.
Canonical const_mx_additive := Additive const_mx_is_additive.

End FixedDim.
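
Once matrix_zmodType is canonical, the generic + of ring_scope applies to matrices and unfolds to addmx. As a hedged sketch (the lemma name ex_addmxE is hypothetical, and the block assumes it sits inside Section MatrixZmodule, where V : zmodType is in scope), addition is computed coefficient by coefficient:

```coq
(* Hypothetical example: A + B is addmx A B, a \matrix_ expression, so mxE
   exposes its coefficients. *)
Lemma ex_addmxE m n (A B : 'M[V]_(m, n)) i j : (A + B) i j = A i j + B i j.
Proof. by rewrite mxE. Qed.
```

The first + is matrix addition and the second is addition in V; both are instances of the same zmodType operation.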

Section Additive.

Variables (m n p q : nat) (f : 'I_p → 'I_q → 'I_m) (g : 'I_p → 'I_q → 'I_n).

Definition swizzle_mx k (A : 'M[V]_(m, n)) :=
  \matrix[k]_(i, j) A (f i j) (g i j).

Lemma swizzle_mx_is_additive k : additive (swizzle_mx k).
Canonical swizzle_mx_additive k := Additive (swizzle_mx_is_additive k).

End Additive.

Local Notation SwizzleAdd op := [additive of op as swizzle_mx _ _ _].

Canonical trmx_additive m n := SwizzleAdd (@trmx V m n).
Canonical row_additive m n i := SwizzleAdd (@row V m n i).
Canonical col_additive m n j := SwizzleAdd (@col V m n j).
Canonical row'_additive m n i := SwizzleAdd (@row' V m n i).
Canonical col'_additive m n j := SwizzleAdd (@col' V m n j).
Canonical row_perm_additive m n s := SwizzleAdd (@row_perm V m n s).
Canonical col_perm_additive m n s := SwizzleAdd (@col_perm V m n s).
Canonical xrow_additive m n i1 i2 := SwizzleAdd (@xrow V m n i1 i2).
Canonical xcol_additive m n j1 j2 := SwizzleAdd (@xcol V m n j1 j2).
Canonical lsubmx_additive m n1 n2 := SwizzleAdd (@lsubmx V m n1 n2).
Canonical rsubmx_additive m n1 n2 := SwizzleAdd (@rsubmx V m n1 n2).
Canonical usubmx_additive m1 m2 n := SwizzleAdd (@usubmx V m1 m2 n).
Canonical dsubmx_additive m1 m2 n := SwizzleAdd (@dsubmx V m1 m2 n).
Canonical vec_mx_additive m n := SwizzleAdd (@vec_mx V m n).
Canonical mxvec_additive m n :=
  Additive (can2_additive (@vec_mxK V m n) mxvecK).

Lemma flatmx0 n : all_equal_to (0 : 'M_(0, n)).

Lemma thinmx0 n : all_equal_to (0 : 'M_(n, 0)).

Lemma trmx0 m n : (0 : 'M_(m, n))^T = 0.

Lemma row0 m n i0 : row i0 (0 : 'M_(m, n)) = 0.

Lemma col0 m n j0 : col j0 (0 : 'M_(m, n)) = 0.

Lemma mxvec_eq0 m n (A : 'M_(m, n)) : (mxvec A == 0) = (A == 0).

Lemma vec_mx_eq0 m n (v : 'rV_(m × n)) : (vec_mx v == 0) = (v == 0).

Lemma row_mx0 m n1 n2 : row_mx 0 0 = 0 :> 'M_(m, n1 + n2).

Lemma col_mx0 m1 m2 n : col_mx 0 0 = 0 :> 'M_(m1 + m2, n).

Lemma block_mx0 m1 m2 n1 n2 : block_mx 0 0 0 0 = 0 :> 'M_(m1 + m2, n1 + n2).

Ltac split_mxE := apply/matrixP⇒ i j; do ![rewrite mxE | case: split ⇒ ?].

Lemma opp_row_mx m n1 n2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  - row_mx A1 A2 = row_mx (- A1) (- A2).

Lemma opp_col_mx m1 m2 n (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  - col_mx A1 A2 = col_mx (- A1) (- A2).

Lemma opp_block_mx m1 m2 n1 n2 (Aul : 'M_(m1, n1)) Aur Adl (Adr : 'M_(m2, n2)) :
  - block_mx Aul Aur Adl Adr = block_mx (- Aul) (- Aur) (- Adl) (- Adr).

Lemma add_row_mx m n1 n2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) B1 B2 :
  row_mx A1 A2 + row_mx B1 B2 = row_mx (A1 + B1) (A2 + B2).

Lemma add_col_mx m1 m2 n (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) B1 B2 :
  col_mx A1 A2 + col_mx B1 B2 = col_mx (A1 + B1) (A2 + B2).

Lemma add_block_mx m1 m2 n1 n2 (Aul : 'M_(m1, n1)) Aur Adl (Adr : 'M_(m2, n2))
                   Bul Bur Bdl Bdr :
  let A := block_mx Aul Aur Adl Adr in let B := block_mx Bul Bur Bdl Bdr in
  A + B = block_mx (Aul + Bul) (Aur + Bur) (Adl + Bdl) (Adr + Bdr).

Definition nz_row m n (A : 'M_(m, n)) :=
  oapp (fun i ⇒ row i A) 0 [pick i | row i A != 0].

Lemma nz_row_eq0 m n (A : 'M_(m, n)) : (nz_row A == 0) = (A == 0).

End MatrixZmodule.

Section FinZmodMatrix.
Variables (V : finZmodType) (m n : nat).
Local Notation MV := 'M[V]_(m, n).

Canonical matrix_finZmodType := Eval hnf in [finZmodType of MV].
Canonical matrix_baseFinGroupType :=
  Eval hnf in [baseFinGroupType of MV for +%R].
Canonical matrix_finGroupType := Eval hnf in [finGroupType of MV for +%R].
End FinZmodMatrix.

Parametricity over the additive structure.
Section MapZmodMatrix.

Variables (aR rR : zmodType) (f : {additive aR → rR}) (m n : nat).
Local Notation "A ^f" := (map_mx f A) : ring_scope.
Implicit Type A : 'M[aR]_(m, n).

Lemma map_mx0 : 0^f = 0 :> 'M_(m, n).

Lemma map_mxN A : (- A)^f = - A^f.

Lemma map_mxD A B : (A + B)^f = A^f + B^f.

Lemma map_mx_sub A B : (A - B)^f = A^f - B^f.

Definition map_mx_sum := big_morph _ map_mxD map_mx0.

Canonical map_mx_additive := Additive map_mx_sub.

End MapZmodMatrix.

Matrix ring module, graded ring, and ring structures

Section MatrixAlgebra.

Variable R : ringType.

Section RingModule.

The ring module/vector space structure

Variables m n : nat.
Implicit Types A B : 'M[R]_(m, n).

Fact scalemx_key : unit.
Definition scalemx x A := \matrix[scalemx_key]_(i, j) (x × A i j).

Basis
Fact delta_mx_key : unit.
Definition delta_mx i0 j0 : 'M[R]_(m, n) :=
  \matrix[delta_mx_key]_(i, j) ((i == i0) && (j == j0))%:R.

Local Notation "x *m: A" := (scalemx x A) (at level 40) : ring_scope.

Lemma scale1mx A : 1 ×m: A = A.

Lemma scalemxDl A x y : (x + y) ×m: A = x ×m: A + y ×m: A.

Lemma scalemxDr x A B : x ×m: (A + B) = x ×m: A + x ×m: B.

Lemma scalemxA x y A : x ×m: (y ×m: A) = (x × y) ×m: A.

Definition matrix_lmodMixin :=
  LmodMixin scalemxA scale1mx scalemxDr scalemxDl.

Canonical matrix_lmodType :=
  Eval hnf in LmodType R 'M[R]_(m, n) matrix_lmodMixin.

Lemma scalemx_const a b : a *: const_mx b = const_mx (a × b).

Lemma matrix_sum_delta A :
  A = \sum_(i < m) \sum_(j < n) A i j *: delta_mx i j.
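
As a reading aid (not part of the library), the statement of matrix_sum_delta can be written in conventional notation. The symbol E_{i_0 j_0} is an assumption of this sketch and stands for delta_mx i0 j0, the m x n matrix whose only nonzero coefficient is a 1 at position (i0, j0):

```latex
(E_{i_0 j_0})_{ij} =
  \begin{cases} 1 & \text{if } i = i_0 \text{ and } j = j_0, \\ 0 & \text{otherwise,} \end{cases}
\qquad
A = \sum_{i < m} \sum_{j < n} A_{ij} \, E_{ij}.
```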

End RingModule.

Section StructuralLinear.

Lemma swizzle_mx_is_scalable m n p q f g k :
  scalable (@swizzle_mx R m n p q f g k).
Canonical swizzle_mx_scalable m n p q f g k :=
  AddLinear (@swizzle_mx_is_scalable m n p q f g k).

Local Notation SwizzleLin op := [linear of op as swizzle_mx _ _ _].

Canonical trmx_linear m n := SwizzleLin (@trmx R m n).
Canonical row_linear m n i := SwizzleLin (@row R m n i).
Canonical col_linear m n j := SwizzleLin (@col R m n j).
Canonical row'_linear m n i := SwizzleLin (@row' R m n i).
Canonical col'_linear m n j := SwizzleLin (@col' R m n j).
Canonical row_perm_linear m n s := SwizzleLin (@row_perm R m n s).
Canonical col_perm_linear m n s := SwizzleLin (@col_perm R m n s).
Canonical xrow_linear m n i1 i2 := SwizzleLin (@xrow R m n i1 i2).
Canonical xcol_linear m n j1 j2 := SwizzleLin (@xcol R m n j1 j2).
Canonical lsubmx_linear m n1 n2 := SwizzleLin (@lsubmx R m n1 n2).
Canonical rsubmx_linear m n1 n2 := SwizzleLin (@rsubmx R m n1 n2).
Canonical usubmx_linear m1 m2 n := SwizzleLin (@usubmx R m1 m2 n).
Canonical dsubmx_linear m1 m2 n := SwizzleLin (@dsubmx R m1 m2 n).
Canonical vec_mx_linear m n := SwizzleLin (@vec_mx R m n).
Definition mxvec_is_linear m n := can2_linear (@vec_mxK R m n) mxvecK.
Canonical mxvec_linear m n := AddLinear (@mxvec_is_linear m n).

End StructuralLinear.

Lemma trmx_delta m n i j : (delta_mx i j)^T = delta_mx j i :> 'M[R]_(n, m).

Lemma row_sum_delta n (u : 'rV_n) : u = \sum_(j < n) u 0 j *: delta_mx 0 j.

Lemma delta_mx_lshift m n1 n2 i j :
  delta_mx i (lshift n2 j) = row_mx (delta_mx i j) 0 :> 'M_(m, n1 + n2).

Lemma delta_mx_rshift m n1 n2 i j :
  delta_mx i (rshift n1 j) = row_mx 0 (delta_mx i j) :> 'M_(m, n1 + n2).

Lemma delta_mx_ushift m1 m2 n i j :
  delta_mx (lshift m2 i) j = col_mx (delta_mx i j) 0 :> 'M_(m1 + m2, n).

Lemma delta_mx_dshift m1 m2 n i j :
  delta_mx (rshift m1 i) j = col_mx 0 (delta_mx i j) :> 'M_(m1 + m2, n).

Lemma vec_mx_delta m n i j :
  vec_mx (delta_mx 0 (mxvec_index i j)) = delta_mx i j :> 'M_(m, n).

Lemma mxvec_delta m n i j :
  mxvec (delta_mx i j) = delta_mx 0 (mxvec_index i j) :> 'rV_(m × n).

Ltac split_mxE := apply/matrixP⇒ i j; do ![rewrite mxE | case: split ⇒ ?].

Lemma scale_row_mx m n1 n2 a (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  a *: row_mx A1 A2 = row_mx (a *: A1) (a *: A2).

Lemma scale_col_mx m1 m2 n a (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  a *: col_mx A1 A2 = col_mx (a *: A1) (a *: A2).

Lemma scale_block_mx m1 m2 n1 n2 a (Aul : 'M_(m1, n1)) (Aur : 'M_(m1, n2))
                                   (Adl : 'M_(m2, n1)) (Adr : 'M_(m2, n2)) :
  a *: block_mx Aul Aur Adl Adr
     = block_mx (a *: Aul) (a *: Aur) (a *: Adl) (a *: Adr).

Diagonal matrices

Fact diag_mx_key : unit.
Definition diag_mx n (d : 'rV[R]_n) :=
  \matrix[diag_mx_key]_(i, j) (d 0 i *+ (i == j)).

Lemma tr_diag_mx n (d : 'rV_n) : (diag_mx d)^T = diag_mx d.

Lemma diag_mx_is_linear n : linear (@diag_mx n).
Canonical diag_mx_additive n := Additive (@diag_mx_is_linear n).
Canonical diag_mx_linear n := Linear (@diag_mx_is_linear n).

Lemma diag_mx_sum_delta n (d : 'rV_n) :
  diag_mx d = \sum_i d 0 i *: delta_mx i i.

Scalar matrix : a diagonal matrix with a constant on the diagonal
Section ScalarMx.

Variable n : nat.

Fact scalar_mx_key : unit.
Definition scalar_mx x : 'M[R]_n :=
  \matrix[scalar_mx_key]_(i , j) (x *+ (i == j)).
Notation "x %:M" := (scalar_mx x) : ring_scope.

Lemma diag_const_mx a : diag_mx (const_mx a) = a%:M :> 'M_n.

Lemma tr_scalar_mx a : (a%:M)^T = a%:M.

Lemma trmx1 : (1%:M)^T = 1%:M.

Lemma scalar_mx_is_additive : additive scalar_mx.
Canonical scalar_mx_additive := Additive scalar_mx_is_additive.

Lemma scale_scalar_mx a1 a2 : a1 *: a2%:M = (a1 × a2)%:M :> 'M_n.

Lemma scalemx1 a : a *: 1%:M = a%:M.

Lemma scalar_mx_sum_delta a : a%:M = \sum_i a *: delta_mx i i.

Lemma mx1_sum_delta : 1%:M = \sum_i delta_mx i i.

Lemma row1 i : row i 1%:M = delta_mx 0 i.

Definition is_scalar_mx (A : 'M[R]_n) :=
  if insub 0%N is Some i then A == (A i i)%:M else true.

Lemma is_scalar_mxP A : reflect (∃ a, A = a%:M) (is_scalar_mx A).

Lemma scalar_mx_is_scalar a : is_scalar_mx a%:M.

Lemma mx0_is_scalar : is_scalar_mx 0.

End ScalarMx.

Notation "x %:M" := (scalar_mx _ x) : ring_scope.

Lemma mx11_scalar (A : 'M_1) : A = (A 0 0)%:M.

Lemma scalar_mx_block n1 n2 a : a%:M = block_mx a%:M 0 0 a%:M :> 'M_(n1 + n2).

Matrix multiplication using bigops.
Fact mulmx_key : unit.
Definition mulmx {m n p} (A : 'M_(m, n)) (B : 'M_(n, p)) : 'M[R]_(m, p) :=
  \matrix[mulmx_key]_(i, k) \sum_j (A i j × B j k).
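
As a reading aid (not part of the library), the coefficient formula behind mulmx can be written in conventional notation, for A of size m x n and B of size n x p. The subscript notation is an assumption of this sketch; the sum ranges over the shared index j, exactly as in the bigop above:

```latex
(A \, B)_{ik} = \sum_{j < n} A_{ij} \, B_{jk}, \qquad i < m, \; k < p.
```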

Local Notation "A *m B" := (mulmx A B) : ring_scope.

Lemma mulmxA m n p q (A : 'M_(m, n)) (B : 'M_(n, p)) (C : 'M_(p, q)) :
  A ×m (B ×m C) = A ×m B ×m C.

Lemma mul0mx m n p (A : 'M_(n, p)) : 0 ×m A = 0 :> 'M_(m, p).

Lemma mulmx0 m n p (A : 'M_(m, n)) : A ×m 0 = 0 :> 'M_(m, p).

Lemma mulmxN m n p (A : 'M_(m, n)) (B : 'M_(n, p)) : A ×m (- B) = - (A ×m B).

Lemma mulNmx m n p (A : 'M_(m, n)) (B : 'M_(n, p)) : - A ×m B = - (A ×m B).

Lemma mulmxDl m n p (A1 A2 : 'M_(m, n)) (B : 'M_(n, p)) :
  (A1 + A2) ×m B = A1 ×m B + A2 ×m B.

Lemma mulmxDr m n p (A : 'M_(m, n)) (B1 B2 : 'M_(n, p)) :
  A ×m (B1 + B2) = A ×m B1 + A ×m B2.

Lemma mulmxBl m n p (A1 A2 : 'M_(m, n)) (B : 'M_(n, p)) :
  (A1 - A2) ×m B = A1 ×m B - A2 ×m B.

Lemma mulmxBr m n p (A : 'M_(m, n)) (B1 B2 : 'M_(n, p)) :
  A ×m (B1 - B2) = A ×m B1 - A ×m B2.

Lemma mulmx_suml m n p (A : 'M_(n, p)) I r P (B_ : I → 'M_(m, n)) :
   (\sum_(i <- r | P i) B_ i) ×m A = \sum_(i <- r | P i) B_ i ×m A.

Lemma mulmx_sumr m n p (A : 'M_(m, n)) I r P (B_ : I → 'M_(n, p)) :
   A ×m (\sum_(i <- r | P i) B_ i) = \sum_(i <- r | P i) A ×m B_ i.

Lemma scalemxAl m n p a (A : 'M_(m, n)) (B : 'M_(n, p)) :
  a *: (A ×m B) = (a *: A) ×m B.
Right scaling associativity requires a commutative ring

Lemma rowE m n i (A : 'M_(m, n)) : row i A = delta_mx 0 i ×m A.

Lemma row_mul m n p (i : 'I_m) A (B : 'M_(n, p)) :
  row i (A ×m B) = row i A ×m B.

Lemma mulmx_sum_row m n (u : 'rV_m) (A : 'M_(m, n)) :
  u ×m A = \sum_i u 0 i *: row i A.

Lemma mul_delta_mx_cond m n p (j1 j2 : 'I_n) (i1 : 'I_m) (k2 : 'I_p) :
  delta_mx i1 j1 ×m delta_mx j2 k2 = delta_mx i1 k2 *+ (j1 == j2).

Lemma mul_delta_mx m n p (j : 'I_n) (i : 'I_m) (k : 'I_p) :
  delta_mx i j ×m delta_mx j k = delta_mx i k.

Lemma mul_delta_mx_0 m n p (j1 j2 : 'I_n) (i1 : 'I_m) (k2 : 'I_p) :
  j1 != j2 → delta_mx i1 j1 ×m delta_mx j2 k2 = 0.

Lemma mul_diag_mx m n d (A : 'M_(m, n)) :
  diag_mx d ×m A = \matrix_(i, j) (d 0 i × A i j).

Lemma mul_mx_diag m n (A : 'M_(m, n)) d :
  A ×m diag_mx d = \matrix_(i, j) (A i j × d 0 j).

Lemma mulmx_diag n (d e : 'rV_n) :
  diag_mx d ×m diag_mx e = diag_mx (\row_j (d 0 j × e 0 j)).

Lemma mul_scalar_mx m n a (A : 'M_(m, n)) : a%:M ×m A = a *: A.

Lemma scalar_mxM n a b : (a × b)%:M = a%:M ×m b%:M :> 'M_n.

Lemma mul1mx m n (A : 'M_(m, n)) : 1%:M ×m A = A.

Lemma mulmx1 m n (A : 'M_(m, n)) : A ×m 1%:M = A.
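
A minimal usage sketch of the identity lemmas, written in plain Coq syntax rather than the rendered form used on this page. It is meant for a hypothetical client file, and assumes that the library is on the load path under the unqualified names used in this file's header and that the global notations *m and %:M declared at the end of the algebra section are in scope. The name mul1mx_twice is made up for the illustration.

```coq
(* Hypothetical client file; the module names mirror this file's header. *)
Require Import ssreflect ssrbool ssrfun eqtype ssrnat seq div choice fintype.
Require Import finfun bigop prime binomial ssralg finset fingroup finalg.
Require Import perm zmodp matrix.

Import GRing.Theory.
Open Scope ring_scope.

(* The size of each identity matrix is inferred from the type of A. *)
Example mul1mx_twice (R : ringType) m n (A : 'M[R]_(m, n)) :
  1%:M *m (1%:M *m A) = A.
Proof. by rewrite !mul1mx. Qed.
```

The same pattern applies on the right with mulmx1; because 1%:M is a square matrix of whatever size the context requires, no dimension annotations are needed here.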

Lemma mul_col_perm m n p s (A : 'M_(m, n)) (B : 'M_(n, p)) :
  col_perm s A ×m B = A ×m row_perm s^-1 B.

Lemma mul_row_perm m n p s (A : 'M_(m, n)) (B : 'M_(n, p)) :
  A ×m row_perm s B = col_perm s^-1 A ×m B.

Lemma mul_xcol m n p j1 j2 (A : 'M_(m, n)) (B : 'M_(n, p)) :
  xcol j1 j2 A ×m B = A ×m xrow j1 j2 B.

Permutation matrix

Definition perm_mx n s : 'M_n := row_perm s 1%:M.

Definition tperm_mx n i1 i2 : 'M_n := perm_mx (tperm i1 i2).

Lemma col_permE m n s (A : 'M_(m, n)) : col_perm s A = A ×m perm_mx s^-1.

Lemma row_permE m n s (A : 'M_(m, n)) : row_perm s A = perm_mx s ×m A.

Lemma xcolE m n j1 j2 (A : 'M_(m, n)) : xcol j1 j2 A = A ×m tperm_mx j1 j2.

Lemma xrowE m n i1 i2 (A : 'M_(m, n)) : xrow i1 i2 A = tperm_mx i1 i2 ×m A.

Lemma tr_perm_mx n (s : 'S_n) : (perm_mx s)^T = perm_mx s^-1.

Lemma tr_tperm_mx n i1 i2 : (tperm_mx i1 i2)^T = tperm_mx i1 i2 :> 'M_n.

Lemma perm_mx1 n : perm_mx 1 = 1%:M :> 'M_n.

Lemma perm_mxM n (s t : 'S_n) : perm_mx (s × t) = perm_mx s ×m perm_mx t.

Definition is_perm_mx n (A : 'M_n) := [∃ s, A == perm_mx s].

Lemma is_perm_mxP n (A : 'M_n) :
  reflect (∃ s, A = perm_mx s) (is_perm_mx A).

Lemma perm_mx_is_perm n (s : 'S_n) : is_perm_mx (perm_mx s).

Lemma is_perm_mx1 n : is_perm_mx (1%:M : 'M_n).

Lemma is_perm_mxMl n (A B : 'M_n) :
  is_perm_mx A → is_perm_mx (A ×m B) = is_perm_mx B.

Lemma is_perm_mx_tr n (A : 'M_n) : is_perm_mx A^T = is_perm_mx A.

Lemma is_perm_mxMr n (A B : 'M_n) :
  is_perm_mx B → is_perm_mx (A ×m B) = is_perm_mx A.

Partial identity matrix (used in rank decomposition).

Fact pid_mx_key : unit.
Definition pid_mx {m n} r : 'M[R]_(m, n) :=
  \matrix[pid_mx_key]_(i, j) ((i == j :> nat) && (i < r))%:R.

Lemma pid_mx_0 m n : pid_mx 0 = 0 :> 'M_(m, n).

Lemma pid_mx_1 r : pid_mx r = 1%:M :> 'M_r.

Lemma pid_mx_row n r : pid_mx r = row_mx 1%:M 0 :> 'M_(r, r + n).

Lemma pid_mx_col m r : pid_mx r = col_mx 1%:M 0 :> 'M_(r + m, r).

Lemma pid_mx_block m n r : pid_mx r = block_mx 1%:M 0 0 0 :> 'M_(r + m, r + n).
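
As a reading aid (not part of the library), the definition of pid_mx and the block form stated by pid_mx_block can be written in conventional notation. The symbols I_r and M_{(r+m) x (r+n)}(R) are assumptions of this sketch:

```latex
(\mathrm{pid\_mx}\; r)_{ij} =
  \begin{cases} 1 & \text{if } i = j \text{ and } i < r, \\ 0 & \text{otherwise,} \end{cases}
\qquad
\mathrm{pid\_mx}\; r =
  \begin{pmatrix} I_r & 0 \\ 0 & 0 \end{pmatrix}
  \in M_{(r+m) \times (r+n)}(R).
```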

Lemma tr_pid_mx m n r : (pid_mx r)^T = pid_mx r :> 'M_(n, m).

Lemma pid_mx_minv m n r : pid_mx (minn m r) = pid_mx r :> 'M_(m, n).

Lemma pid_mx_minh m n r : pid_mx (minn n r) = pid_mx r :> 'M_(m, n).

Lemma mul_pid_mx m n p q r :
  (pid_mx q : 'M_(m, n)) ×m (pid_mx r : 'M_(n, p)) = pid_mx (minn n (minn q r)).

Lemma pid_mx_id m n p r :
  r ≤ n → (pid_mx r : 'M_(m, n)) ×m (pid_mx r : 'M_(n, p)) = pid_mx r.

Definition copid_mx {n} r : 'M_n := 1%:M - pid_mx r.

Lemma mul_copid_mx_pid m n r :
  r ≤ m → copid_mx r ×m pid_mx r = 0 :> 'M_(m, n).

Lemma mul_pid_mx_copid m n r :
  r ≤ n → pid_mx r ×m copid_mx r = 0 :> 'M_(m, n).

Lemma copid_mx_id n r :
  r ≤ n → copid_mx r ×m copid_mx r = copid_mx r :> 'M_n.

Block products; we cover all 1 x 2, 2 x 1, and 2 x 2 block products.
Lemma mul_mx_row m n p1 p2 (A : 'M_(m, n)) (Bl : 'M_(n, p1)) (Br : 'M_(n, p2)) :
  A ×m row_mx Bl Br = row_mx (A ×m Bl) (A ×m Br).

Lemma mul_col_mx m1 m2 n p (Au : 'M_(m1, n)) (Ad : 'M_(m2, n)) (B : 'M_(n, p)) :
  col_mx Au Ad ×m B = col_mx (Au ×m B) (Ad ×m B).

Lemma mul_row_col m n1 n2 p (Al : 'M_(m, n1)) (Ar : 'M_(m, n2))
                            (Bu : 'M_(n1, p)) (Bd : 'M_(n2, p)) :
  row_mx Al Ar ×m col_mx Bu Bd = Al ×m Bu + Ar ×m Bd.

Lemma mul_col_row m1 m2 n p1 p2 (Au : 'M_(m1, n)) (Ad : 'M_(m2, n))
                                (Bl : 'M_(n, p1)) (Br : 'M_(n, p2)) :
  col_mx Au Ad ×m row_mx Bl Br
     = block_mx (Au ×m Bl) (Au ×m Br) (Ad ×m Bl) (Ad ×m Br).

Lemma mul_row_block m n1 n2 p1 p2 (Al : 'M_(m, n1)) (Ar : 'M_(m, n2))
                                  (Bul : 'M_(n1, p1)) (Bur : 'M_(n1, p2))
                                  (Bdl : 'M_(n2, p1)) (Bdr : 'M_(n2, p2)) :
  row_mx Al Ar ×m block_mx Bul Bur Bdl Bdr
   = row_mx (Al ×m Bul + Ar ×m Bdl) (Al ×m Bur + Ar ×m Bdr).

Lemma mul_block_col m1 m2 n1 n2 p (Aul : 'M_(m1, n1)) (Aur : 'M_(m1, n2))
                                  (Adl : 'M_(m2, n1)) (Adr : 'M_(m2, n2))
                                  (Bu : 'M_(n1, p)) (Bd : 'M_(n2, p)) :
  block_mx Aul Aur Adl Adr ×m col_mx Bu Bd
   = col_mx (Aul ×m Bu + Aur ×m Bd) (Adl ×m Bu + Adr ×m Bd).

Lemma mulmx_block m1 m2 n1 n2 p1 p2 (Aul : 'M_(m1, n1)) (Aur : 'M_(m1, n2))
                                    (Adl : 'M_(m2, n1)) (Adr : 'M_(m2, n2))
                                    (Bul : 'M_(n1, p1)) (Bur : 'M_(n1, p2))
                                    (Bdl : 'M_(n2, p1)) (Bdr : 'M_(n2, p2)) :
  block_mx Aul Aur Adl Adr ×m block_mx Bul Bur Bdl Bdr
    = block_mx (Aul ×m Bul + Aur ×m Bdl) (Aul ×m Bur + Aur ×m Bdr)
               (Adl ×m Bul + Adr ×m Bdl) (Adl ×m Bur + Adr ×m Bdr).

Correspondence between matrices and linear functions on row vectors.
Section LinRowVector.

Variables m n : nat.

Fact lin1_mx_key : unit.
Definition lin1_mx (f : 'rV[R]_m → 'rV[R]_n) :=
  \matrix[lin1_mx_key]_(i, j) f (delta_mx 0 i) 0 j.

Variable f : {linear 'rV[R]_m → 'rV[R]_n}.

Lemma mul_rV_lin1 u : u ×m lin1_mx f = f u.

End LinRowVector.

Correspondence between matrices and linear functions on matrices.
Section LinMatrix.

Variables m1 n1 m2 n2 : nat.

Definition lin_mx (f : 'M[R]_(m1, n1) → 'M[R]_(m2, n2)) :=
  lin1_mx (mxvec \o f \o vec_mx).

Variable f : {linear 'M[R]_(m1, n1) → 'M[R]_(m2, n2)}.

Lemma mul_rV_lin u : u ×m lin_mx f = mxvec (f (vec_mx u)).

Lemma mul_vec_lin A : mxvec A ×m lin_mx f = mxvec (f A).

Lemma mx_rV_lin u : vec_mx (u ×m lin_mx f) = f (vec_mx u).

Lemma mx_vec_lin A : vec_mx (mxvec A ×m lin_mx f) = f A.

End LinMatrix.

Canonical mulmx_additive m n p A := Additive (@mulmxBr m n p A).

Section Mulmxr.

Variables m n p : nat.
Implicit Type A : 'M[R]_(m, n).
Implicit Type B : 'M[R]_(n, p).

Definition mulmxr_head t B A := let: tt := t in A ×m B.
Local Notation mulmxr := (mulmxr_head tt).

Definition lin_mulmxr B := lin_mx (mulmxr B).

Lemma mulmxr_is_linear B : linear (mulmxr B).
Canonical mulmxr_additive B := Additive (mulmxr_is_linear B).
Canonical mulmxr_linear B := Linear (mulmxr_is_linear B).

Lemma lin_mulmxr_is_linear : linear lin_mulmxr.
Canonical lin_mulmxr_additive := Additive lin_mulmxr_is_linear.
Canonical lin_mulmxr_linear := Linear lin_mulmxr_is_linear.

End Mulmxr.

The trace.
Section Trace.

Variable n : nat.

Definition mxtrace (A : 'M[R]_n) := \sum_i A i i.
Local Notation "'\tr' A" := (mxtrace A) : ring_scope.

Lemma mxtrace_tr A : \tr A^T = \tr A.

Lemma mxtrace_is_scalar : scalar mxtrace.
Canonical mxtrace_additive := Additive mxtrace_is_scalar.
Canonical mxtrace_linear := Linear mxtrace_is_scalar.

Lemma mxtrace0 : \tr 0 = 0.
Lemma mxtraceD A B : \tr (A + B) = \tr A + \tr B.
Lemma mxtraceZ a A : \tr (a *: A) = a × \tr A.

Lemma mxtrace_diag D : \tr (diag_mx D) = \sum_j D 0 j.

Lemma mxtrace_scalar a : \tr a%:M = a *+ n.

Lemma mxtrace1 : \tr 1%:M = n%:R.

End Trace.
Local Notation "'\tr' A" := (mxtrace A) : ring_scope.

Lemma trace_mx11 (A : 'M_1) : \tr A = A 0 0.

Lemma mxtrace_block n1 n2 (Aul : 'M_n1) Aur Adl (Adr : 'M_n2) :
  \tr (block_mx Aul Aur Adl Adr) = \tr Aul + \tr Adr.
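
A small usage sketch for the trace lemmas, in plain Coq syntax. It is meant for a hypothetical client file, and assumes that the library is on the load path under the unqualified names used in this file's header and that the global \tr and %:M notations are in scope. The name mxtrace_scalar2 is made up for the illustration; mulr2n is the ssralg lemma x *+ 2 = x + x.

```coq
(* Hypothetical client file; the module names mirror this file's header. *)
Require Import ssreflect ssrbool ssrfun eqtype ssrnat seq div choice fintype.
Require Import finfun bigop prime binomial ssralg finset fingroup finalg.
Require Import perm zmodp matrix.

Import GRing.Theory.
Open Scope ring_scope.

(* The trace of the 2 x 2 scalar matrix a%:M is a *+ 2, that is, a + a. *)
Example mxtrace_scalar2 (R : ringType) (a : R) :
  \tr (a%:M : 'M[R]_2) = a + a.
Proof. by rewrite mxtrace_scalar mulr2n. Qed.
```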

The matrix ring structure requires a structural condition (dimension of the form n.+1) to satisfy the nontriviality condition we have imposed.
Section MatrixRing.

Variable n' : nat.
Local Notation n := n'.+1.

Lemma matrix_nonzero1 : 1%:M != 0 :> 'M_n.

Definition matrix_ringMixin :=
  RingMixin (@mulmxA n n n n) (@mul1mx n n) (@mulmx1 n n)
            (@mulmxDl n n n) (@mulmxDr n n n) matrix_nonzero1.

Canonical matrix_ringType := Eval hnf in RingType 'M[R]_n matrix_ringMixin.
Canonical matrix_lAlgType := Eval hnf in LalgType R 'M[R]_n (@scalemxAl n n n).

Lemma mulmxE : mulmx = *%R.
Lemma idmxE : 1%:M = 1 :> 'M_n.

Lemma scalar_mx_is_multiplicative : multiplicative (@scalar_mx n).
Canonical scalar_mx_rmorphism := AddRMorphism scalar_mx_is_multiplicative.

End MatrixRing.

Section LiftPerm.

Block expression of a lifted permutation matrix, for the Cormen LUP.

Variable n : nat.

These could be in zmodp, but that would introduce a dependency on perm.

Definition lift0_perm s : 'S_n.+1 := lift_perm 0 0 s.

Lemma lift0_perm0 s : lift0_perm s 0 = 0.

Lemma lift0_perm_lift s k' :
  lift0_perm s (lift 0 k') = lift (0 : 'I_n.+1) (s k').

Lemma lift0_permK s : cancel (lift0_perm s) (lift0_perm s^-1).

Lemma lift0_perm_eq0 s i : (lift0_perm s i == 0) = (i == 0).

Block expression of a lifted permutation matrix

Definition lift0_mx A : 'M_(1 + n) := block_mx 1 0 0 A.

Lemma lift0_mx_perm s : lift0_mx (perm_mx s) = perm_mx (lift0_perm s).

Lemma lift0_mx_is_perm s : is_perm_mx (lift0_mx (perm_mx s)).

End LiftPerm.

Determinants and adjugates are defined here, but most of their properties only hold for matrices over a commutative ring, so their theory is deferred to that section.
The determinant, in one line with the Leibniz Formula
Definition determinant n (A : 'M_n) : R :=
  \sum_(s : 'S_n) (-1) ^+ s × \prod_i A i (s i).

The cofactor of a matrix on the indexes i and j
Definition cofactor n A (i j : 'I_n) : R :=
  (-1) ^+ (i + j) × determinant (row' i (col' j A)).

The adjugate matrix : defined as the transpose of the matrix of cofactors
Fact adjugate_key : unit.
Definition adjugate n (A : 'M_n) := \matrix[adjugate_key]_(i, j) cofactor A j i.
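
As a reading aid (not part of the library), the three definitions above can be written in conventional notation. The symbols are assumptions of this sketch: sgn(s) stands for (-1) ^+ s, the sign of the permutation s, and A^{(i,j)} stands for row' i (col' j A), that is, A with row i and column j removed.

```latex
\det A = \sum_{s \in S_n} \operatorname{sgn}(s) \prod_{i < n} A_{i,\, s(i)},
\qquad
\mathrm{cofactor}(A)_{ij} = (-1)^{i+j} \det A^{(i,j)},
\qquad
(\operatorname{adj} A)_{ij} = \mathrm{cofactor}(A)_{ji}.
```

The index swap in the last formula is what makes the adjugate the transpose of the matrix of cofactors.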

End MatrixAlgebra.

Implicit Arguments delta_mx [R m n].
Implicit Arguments scalar_mx [R n].
Implicit Arguments perm_mx [R n].
Implicit Arguments tperm_mx [R n].
Implicit Arguments pid_mx [R m n].
Implicit Arguments copid_mx [R n].
Implicit Arguments lin_mulmxr [R m n p].

Implicit Arguments is_scalar_mxP [R n A].
Implicit Arguments mul_delta_mx [R m n p].

Notation "a %:M" := (scalar_mx a) : ring_scope.
Notation "A *m B" := (mulmx A B) : ring_scope.
Notation mulmxr := (mulmxr_head tt).
Notation "\tr A" := (mxtrace A) : ring_scope.
Notation "'\det' A" := (determinant A) : ring_scope.
Notation "'\adj' A" := (adjugate A) : ring_scope.

Non-commutative transpose requires multiplication in the converse ring.
Lemma trmx_mul_rev (R : ringType) m n p (A : 'M[R]_(m, n)) (B : 'M[R]_(n, p)) :
  (A ×m B)^T = (B : 'M[R^c]_(n, p))^T ×m (A : 'M[R^c]_(m, n))^T.

Canonical matrix_finRingType (R : finRingType) n' :=
  Eval hnf in [finRingType of 'M[R]_n'.+1].

Parametricity over the algebra structure.
Section MapRingMatrix.

Variables (aR rR : ringType) (f : {rmorphism aR → rR}).
Local Notation "A ^f" := (map_mx f A) : ring_scope.

Section FixedSize.

Variables m n p : nat.
Implicit Type A : 'M[aR]_(m, n).

Lemma map_mxZ a A : (a *: A)^f = f a *: A^f.

Lemma map_mxM A B : (A ×m B)^f = A^f ×m B^f :> 'M_(m, p).

Lemma map_delta_mx i j : (delta_mx i j)^f = delta_mx i j :> 'M_(m, n).

Lemma map_diag_mx d : (diag_mx d)^f = diag_mx d^f :> 'M_n.

Lemma map_scalar_mx a : a%:M^f = (f a)%:M :> 'M_n.

Lemma map_mx1 : 1%:M^f = 1%:M :> 'M_n.

Lemma map_perm_mx (s : 'S_n) : (perm_mx s)^f = perm_mx s.

Lemma map_tperm_mx (i1 i2 : 'I_n) : (tperm_mx i1 i2)^f = tperm_mx i1 i2.

Lemma map_pid_mx r : (pid_mx r)^f = pid_mx r :> 'M_(m, n).

Lemma trace_map_mx (A : 'M_n) : \tr A^f = f (\tr A).

Lemma det_map_mx n' (A : 'M_n') : \det A^f = f (\det A).

Lemma cofactor_map_mx (A : 'M_n) i j : cofactor A^f i j = f (cofactor A i j).

Lemma map_mx_adj (A : 'M_n) : (\adj A)^f = \adj A^f.

End FixedSize.

Lemma map_copid_mx n r : (copid_mx r)^f = copid_mx r :> 'M_n.

Lemma map_mx_is_multiplicative n' (n := n'.+1) :
  multiplicative ((map_mx f) n n).

Canonical map_mx_rmorphism n' := AddRMorphism (map_mx_is_multiplicative n').

Lemma map_lin1_mx m n (g : 'rV_m → 'rV_n) gf :
  (∀ v, (g v)^f = gf v^f) → (lin1_mx g)^f = lin1_mx gf.

Lemma map_lin_mx m1 n1 m2 n2 (g : 'M_(m1, n1) → 'M_(m2, n2)) gf :
  (∀ A, (g A)^f = gf A^f) → (lin_mx g)^f = lin_mx gf.

End MapRingMatrix.

Section ComMatrix.
Lemmas for matrices with coefficients in a commutative ring
Variable R : comRingType.

Section AssocLeft.

Variables m n p : nat.
Implicit Type A : 'M[R]_(m, n).
Implicit Type B : 'M[R]_(n, p).

Lemma trmx_mul A B : (A ×m B)^T = B^T ×m A^T.

Lemma scalemxAr a A B : a *: (A ×m B) = A ×m (a *: B).

Lemma mulmx_is_scalable A : scalable (@mulmx _ m n p A).
Canonical mulmx_linear A := AddLinear (mulmx_is_scalable A).

Definition lin_mulmx A : 'M[R]_(n × p, m × p) := lin_mx (mulmx A).

Lemma lin_mulmx_is_linear : linear lin_mulmx.
Canonical lin_mulmx_additive := Additive lin_mulmx_is_linear.
Canonical lin_mulmx_linear := Linear lin_mulmx_is_linear.

End AssocLeft.

Section LinMulRow.

Variables m n : nat.

Definition lin_mul_row u : 'M[R]_(m × n, n) := lin1_mx (mulmx u \o vec_mx).

Lemma lin_mul_row_is_linear : linear lin_mul_row.
Canonical lin_mul_row_additive := Additive lin_mul_row_is_linear.
Canonical lin_mul_row_linear := Linear lin_mul_row_is_linear.

Lemma mul_vec_lin_row A u : mxvec A ×m lin_mul_row u = u ×m A.

End LinMulRow.

Lemma mxvec_dotmul m n (A : 'M[R]_(m, n)) u v :
  mxvec (u^T ×m v) ×m (mxvec A)^T = u ×m A ×m v^T.

Section MatrixAlgType.

Variable n' : nat.
Local Notation n := n'.+1.

Canonical matrix_algType :=
  Eval hnf in AlgType R 'M[R]_n (fun k ⇒ scalemxAr k).

End MatrixAlgType.

Lemma diag_mxC n (d e : 'rV[R]_n) :
  diag_mx d ×m diag_mx e = diag_mx e ×m diag_mx d.

Lemma diag_mx_comm n' (d e : 'rV[R]_n'.+1) : GRing.comm (diag_mx d) (diag_mx e).

Lemma scalar_mxC m n a (A : 'M[R]_(m, n)) : A ×m a%:M = a%:M ×m A.

Lemma scalar_mx_comm n' a (A : 'M[R]_n'.+1) : GRing.comm A a%:M.

Lemma mul_mx_scalar m n a (A : 'M[R]_(m, n)) : A ×m a%:M = a *: A.

Lemma mxtrace_mulC m n (A : 'M[R]_(m, n)) (B : 'M_(n, m)) :
  \tr (A ×m B) = \tr (B ×m A).

The theory of determinants

Lemma determinant_multilinear n (A B C : 'M[R]_n) i0 b c :
    row i0 A = b *: row i0 B + c *: row i0 C →
    row' i0 B = row' i0 A →
    row' i0 C = row' i0 A →
  \det A = b × \det B + c × \det C.

Lemma determinant_alternate n (A : 'M[R]_n) i1 i2 :
  i1 != i2 → A i1 =1 A i2 → \det A = 0.

Lemma det_tr n (A : 'M[R]_n) : \det A^T = \det A.

Lemma det_perm n (s : 'S_n) : \det (perm_mx s) = (-1) ^+ s :> R.

Lemma det1 n : \det (1%:M : 'M[R]_n) = 1.

Lemma det_mx00 (A : 'M[R]_0) : \det A = 1.

Lemma detZ n a (A : 'M[R]_n) : \det (a *: A) = a ^+ n × \det A.

Lemma det0 n' : \det (0 : 'M[R]_n'.+1) = 0.

Lemma det_scalar n a : \det (a%:M : 'M[R]_n) = a ^+ n.

Lemma det_scalar1 a : \det (a%:M : 'M[R]_1) = a.

Lemma det_mulmx n (A B : 'M[R]_n) : \det (A ×m B) = \det A × \det B.

Lemma detM n' (A B : 'M[R]_n'.+1) : \det (A × B) = \det A × \det B.

Lemma det_diag n (d : 'rV[R]_n) : \det (diag_mx d) = \prod_i d 0 i.

Laplace expansion lemma
Lemma expand_cofactor n (A : 'M[R]_n) i j :
  cofactor A i j =
    \sum_(s : 'S_n | s i == j) (-1) ^+ s × \prod_(k | i != k) A k (s k).

Lemma expand_det_row n (A : 'M[R]_n) i0 :
  \det A = \sum_j A i0 j × cofactor A i0 j.

Lemma cofactor_tr n (A : 'M[R]_n) i j : cofactor A^T i j = cofactor A j i.

Lemma cofactorZ n a (A : 'M[R]_n) i j :
  cofactor (a *: A) i j = a ^+ n.-1 × cofactor A i j.

Lemma