
#include <stdio.h>
#include "matrix.h"

/* This file contains functions for computing inverse of a square matrix:
 * 	A function for lower triangular matrix
 * 	A function for upper triangular matrix
 * 	A function for permutation matrix
 * 	A function for general matrix
 * 	A function for computing LUP decomposition of a matrix
 */



/* Computes the LUP decomposition of the n x n matrix M using recursion.
 * If M is invertible, the decomposition gives three n x n matrices
 * L, U and P such that M = L * U * P, where
 * L is a lower triangular matrix with all 1's on its diagonal,
 * U is an upper triangular matrix, and P is a permutation matrix.
 * L, U and P must be allocated by the caller with the same size as M;
 * they are only guaranteed to be fully written when the returned
 * determinant is non-zero.
 * The function returns the determinant of M (0 when M is found singular).
 *
 * The algorithm used is as follows. First M is copied into A.
 * Then a permutation matrix tildeP is constructed such that
 * A = B * tildeP and the first element of the first column of B is non-zero.
 * After that B is split into tildeL * C, where tildeL is a lower triangular
 * matrix with all 1's on its diagonal and in C all elements of the first
 * column except the first one are zero.
 * (A is modified successively to become B and then C in the program.)
 * C is then copied into Cprime after dropping its first row and column.
 * Cprime is recursively decomposed into Lprime * Uprime * Pprime.
 * Matrix L is the product of tildeL and Lprime (extended to the size of M);
 * matrix U is Uprime (extended to the size of M by adding the first row
 * and column of C, with the first row shuffled according to
 * Pprime-transpose); and matrix P is the product of Pprime (extended to
 * the size of M) and tildeP.
 *
 * Since M = tildeL * C * tildeP and det(tildeL) = 1:
 *     det(M) = C[0][0] * det(Cprime) * det(tildeP).
 */
float LUP_decompose(Matrix M, Matrix L, Matrix U, Matrix P)
{
	float det_value; // stores the value of the determinant
	float sign_tildeP = 1.0f; // det(tildeP): -1 after a column swap, 1 otherwise
	Matrix A; // matrix initialized to M
	Matrix tildeP; // A = B * tildeP, B has top-left element non-zero
	Matrix tildeL; // B = tildeL * C, C has the first column all zero except the first element
	Matrix Cprime; // equals C minus the first row and column
	Matrix Lprime; // Cprime = Lprime * Uprime * Pprime
	Matrix Uprime;
	Matrix Pprime;

	if (M.rows == 1) { // 1 x 1 matrix: L = P = [1], U = M
		L.elements[0][0] = 1;
		U.elements[0][0] = M.elements[0][0];
		P.elements[0][0] = 1;
		return M.elements[0][0];
	}

	A = allocate_matrix(M.rows, M.cols); // allocate space for A
	copy_matrix(A, 0, M, 0, M.rows, M.cols); // copy M to A

	// Initialize tildeP to identity matrix
	tildeP = allocate_matrix(M.rows, M.cols);
	set_to_identity(tildeP);

	// NOTE(review): exact comparison and no partial pivoting: a tiny non-zero
	// pivot is accepted as-is and may amplify rounding errors.
	if (A.elements[0][0] == 0) {
		// first element of first column zero: find a column with non-zero first element and swap
		int i = find_nonzero_column(A);
		if (i >= A.cols) { // no such column: determinant is 0
			// release what has been allocated so far before bailing out
			free_matrix(A);
			free_matrix(tildeP);
			return 0.0;
		}
		swap_column(A, 0, i); // swap column #i with column #0, the "new A" equals B now

		// Set matrix tildeP such that "old A" = "new A" * tildeP
		tildeP.elements[0][0] = tildeP.elements[i][i] = 0;
		tildeP.elements[0][i] = tildeP.elements[i][0] = 1;
		if (i != 0) // a transposition of two distinct columns has determinant -1
			sign_tildeP = -1.0f;
	}

	/* Construct the lower triangular matrix tildeL. Simultaneously, also
	 * make the first column of A all zero except the first element.
	 * After this, A becomes C.
	 */
	// First set tildeL to identity
	tildeL = allocate_matrix(A.rows, A.cols);
	set_to_identity(tildeL);

	for (int t = 1; t < A.rows; t++) {
		tildeL.elements[t][0] = A.elements[t][0] / A.elements[0][0]; // record the factor in tildeL[t][0]
		add_row(A.elements[t], A.elements[0], -tildeL.elements[t][0], A.cols); // make first column of A[t] zero
	}
	// At this point "new A" is C, and "old A" = tildeL * "new A" (i.e. B = tildeL * C)

	// Drop the first row and column of A
	Cprime = allocate_matrix(A.rows-1, A.cols-1);
	copy_matrix(Cprime, 0, A, 1, A.rows-1, A.cols-1);

	// Recursively decompose Cprime
	Lprime = allocate_matrix(A.rows-1, A.cols-1);
	Uprime = allocate_matrix(A.rows-1, A.cols-1);
	Pprime = allocate_matrix(A.rows-1, A.cols-1);

	det_value = LUP_decompose(Cprime, Lprime, Uprime, Pprime); // Cprime = Lprime * Uprime * Pprime

	if (is_zero(det_value)) { // determinant is zero, so no inverse exists
		det_value = 0.0;
		goto cleanup;
	}

	// Compute L = tildeL * (Lprime extended to size of A)
	// First extend Lprime to size of A and store in L
	set_to_identity(L);
	copy_matrix(L, 1, Lprime, 0, Lprime.rows, Lprime.cols);
	// Now multiply tildeL to L and store in L.
	// NOTE(review): L is both an operand and the destination here (as is P
	// below); this relies on multiply_matrix tolerating aliased arguments -- confirm.
	multiply_matrix(tildeL, L, &L);

	// Compute U = (Uprime extended to size of A by adding the first row of C
	// shuffled by Pprime-transpose): U[0][k] = sum_t C[0][t] * Pprime[k-1][t-1].
	// The off-diagonal zeros from set_to_identity serve as the initial sums.
	set_to_identity(U);
	U.elements[0][0] = A.elements[0][0];
	for (int k = 1; k < U.rows; k++) {
		for (int t = 1; t < U.cols; t++)
			U.elements[0][k] = U.elements[0][k] + A.elements[0][t] * Pprime.elements[k-1][t-1];
	}
	copy_matrix(U, 1, Uprime, 0, Uprime.rows, Uprime.cols);

	// Compute P = (Pprime extended to size of A) * tildeP
	set_to_identity(P);
	copy_matrix(P, 1, Pprime, 0, Pprime.rows, Pprime.cols);
	multiply_matrix(P, tildeP, &P);

	/* det(M) = C[0][0] * det(Cprime) * det(tildeP). det_value is det(Cprime)
	 * and already carries the sign of Pprime, so det(P) = det(Pprime) *
	 * det(tildeP) must not be used here: that would apply Pprime's sign a
	 * second time and flip the result whenever Pprime is an odd permutation.
	 */
	det_value = det_value * A.elements[0][0] * sign_tildeP;

cleanup:
	// free space and return
	free_matrix(A);
	free_matrix(tildeL);
	free_matrix(tildeP);
	free_matrix(Cprime);
	free_matrix(Lprime);
	free_matrix(Uprime);
	free_matrix(Pprime);
	return det_value;
}


/* Inverts the given lower triangular matrix L and stores the inverse in
 * invL, which must be allocated by the caller with the same size as L.
 * The inverse of a lower triangular matrix is itself lower triangular;
 * the entries of invL above the diagonal are explicitly set to zero, so
 * invL is a complete inverse whatever its previous contents were.
 * Returns the determinant of L (the product of its diagonal elements).
 * If some diagonal element is zero, L is singular: 0 is returned and
 * invL is left untouched. Note that the returned float product may
 * underflow to 0 (or overflow) even when L is invertible.
 */
float inv_lower(Matrix L, Matrix invL)
{
	float det = 1.0; // product of the diagonal elements of L

	/* L is singular iff one of its diagonal elements is zero. Test the
	 * elements themselves rather than the running product, which can
	 * underflow to 0 for an invertible matrix. Checking up front also
	 * guarantees that no division below is by zero.
	 */
	for (int i = 0; i < L.rows; i++) {
		if (L.elements[i][i] == 0.0) // matrix not invertible
			return 0.0;
		det = det * L.elements[i][i]; // update determinant
	}

	for (int i = 0; i < L.rows; i++) {
		// Entries above the diagonal of the inverse are zero
		for (int k = i + 1; k < L.cols; k++)
			invL.elements[i][k] = 0.0;

		// Diagonal: invL[i][i] * L[i][i] = 1
		invL.elements[i][i] = 1.0 / L.elements[i][i];

		// Left of the diagonal, for k = i-1 down to 0, use the equation
		// \sum_{j=k}^i invL[i][j] * L[j][k] = 0, where every invL[i][j]
		// with j > k has already been computed.
		for (int k = i - 1; k >= 0; k--) {
			invL.elements[i][k] = 0.0; // initialize
			for (int j = i; j > k; j--)
				invL.elements[i][k] = invL.elements[i][k] -
						invL.elements[i][j] * L.elements[j][k];
			invL.elements[i][k] = invL.elements[i][k] / L.elements[k][k]; // final division
		}
	}

	return det;
}


/* Inverts the given upper triangular matrix U and stores the inverse in
 * invU, which must be allocated by the caller with the same size as U.
 * The inverse of an upper triangular matrix is itself upper triangular;
 * the entries of invU below the diagonal are explicitly set to zero, so
 * invU is a complete inverse whatever its previous contents were.
 * Returns the determinant of U (the product of its diagonal elements).
 * If some diagonal element is zero, U is singular: 0 is returned and
 * invU is left untouched. Note that the returned float product may
 * underflow to 0 (or overflow) even when U is invertible.
 */
float inv_upper(Matrix U, Matrix invU)
{
	float det = 1.0; // product of the diagonal elements of U

	/* U is singular iff one of its diagonal elements is zero. Test the
	 * elements themselves rather than the running product, which can
	 * underflow to 0 for an invertible matrix. Checking up front also
	 * matters because row i divides by U[k][k] for k > i, i.e. by
	 * diagonal elements of later rows.
	 */
	for (int i = 0; i < U.rows; i++) {
		if (U.elements[i][i] == 0.0) // matrix not invertible
			return 0.0;
		det = det * U.elements[i][i]; // update determinant
	}

	for (int i = 0; i < U.rows; i++) {
		// Entries below the diagonal of the inverse are zero
		for (int k = 0; k < i; k++)
			invU.elements[i][k] = 0.0;

		// Diagonal: invU[i][i] * U[i][i] = 1
		invU.elements[i][i] = 1.0 / U.elements[i][i];

		// Right of the diagonal, for k = i+1 up to n-1, use the equation
		// \sum_{j=i}^k invU[i][j] * U[j][k] = 0, where every invU[i][j]
		// with j < k has already been computed.
		for (int k = i + 1; k < U.cols; k++) {
			invU.elements[i][k] = 0.0; // initialize
			for (int j = i; j < k; j++)
				invU.elements[i][k] = invU.elements[i][k] -
						invU.elements[i][j] * U.elements[j][k];
			invU.elements[i][k] = invU.elements[i][k] / U.elements[k][k]; // final division
		}
	}

	return det;
}


/* Stores in invP the inverse of the permutation matrix P.
 * A permutation matrix is orthogonal, so its inverse is its transpose:
 * invP[r][c] = P[c][r]. invP must be allocated by the caller with the
 * same size as P.
 */
void inv_permutation(Matrix P, Matrix invP)
{
	for (int r = 0; r < P.rows; r++) {
		int c = 0;
		while (c < P.cols) {
			invP.elements[r][c] = P.elements[c][r];
			++c;
		}
	}
}


/* Computes the inverse of the square matrix A and stores it in invA,
 * which must be allocated by the caller with the same size as A.
 * A is factored as A = L * U * P, and the inverse is assembled as
 * inv(A) = inv(P) * inv(U) * inv(L).
 * Returns the determinant reported by LUP_decompose. When is_zero()
 * holds for it, A is treated as singular and invA is left untouched.
 */
float inv_matrix(Matrix A, Matrix invA)
{
	const int n_rows = A.rows;
	const int n_cols = A.cols;

	// Factors of A = L * U * P
	Matrix lower = allocate_matrix(n_rows, n_cols);
	Matrix upper = allocate_matrix(n_rows, n_cols);
	Matrix perm = allocate_matrix(n_rows, n_cols);

	float det = LUP_decompose(A, lower, upper, perm);

	if (!is_zero(det)) {
		Matrix lower_inv = allocate_matrix(n_rows, n_cols);
		Matrix upper_inv = allocate_matrix(n_rows, n_cols);
		Matrix perm_inv = allocate_matrix(n_rows, n_cols);

		inv_lower(lower, lower_inv);
		inv_upper(upper, upper_inv);
		inv_permutation(perm, perm_inv);

		// invA = inv(U) * inv(L), then invA = inv(P) * invA.
		// NOTE(review): invA is both an operand and the destination of the
		// second product; this relies on multiply_matrix tolerating aliased
		// arguments -- confirm.
		multiply_matrix(upper_inv, lower_inv, &invA);
		multiply_matrix(perm_inv, invA, &invA);

		free_matrix(lower_inv);
		free_matrix(upper_inv);
		free_matrix(perm_inv);
	}

	free_matrix(lower);
	free_matrix(upper);
	free_matrix(perm);
	return det;
}


