#pragma once
#include <cgv/math/vec.h>
#include <cgv/math/mat.h>
#include <cgv/math/perm_mat.h>
#include <cgv/math/diag_mat.h>
#include <cgv/math/tri_diag_mat.h>
#include <cgv/math/up_tri_mat.h>
#include <cgv/math/low_tri_mat.h>
#include <cgv/math/lu.h>
#include <cgv/math/svd.h>
#include <cgv/math/qr.h>


namespace cgv {
	namespace math {

///solves linear system ax=b by back substitution
///a is an upper triangular matrix, b must have one entry per row of a
///returns false if a has a zero on its diagonal; in that case only the
///entries of x below the failing row have been computed
template<typename T>
bool solve(const up_tri_mat<T>& a,const vec<T>&b, vec<T>&x) 
{
	assert(a.nrows() == a.ncols());
	// b[i] is read for every row, so a shorter b would be accessed out of range
	assert(b.dim() == a.nrows());

	int N = a.nrows();
	x.resize(N);
	for(int i = N-1; i >= 0; i--)
	{
		// zero pivot: a is singular, no need to accumulate this row
		if(a(i,i) == 0)
			return false;
		T sum = 0;
		for(int j = i+1; j < N; j++)
			sum += a(i,j)*x(j);
		x[i] = (b[i] - sum)/a(i,i);
	}
	return true;
}

///solves multiple linear systems ax=b 
///a is an upper triangular matrix
///x is the matrix of solution vectors (columns)
///b is the matrix of right-hand sides (columns)
template<typename T>
bool solve(const up_tri_mat<T>& a,const mat<T>&b,mat<T>&x) 
{
	assert(b.nrows() == a.ncols());
	vec<T> xcol;
	x.resize(b.nrows(),b.ncols());
	for(unsigned i = 0; i < b.ncols();i++)
	{
		if(!solve(a,b.col(i),xcol))
			return false;
		x.set_col(i,xcol);
	}
	return true;
}

///solves linear system ax=b by forward substitution
///a is a lower triangular matrix, b must have one entry per row of a
///returns false if a has a zero on its diagonal; in that case only the
///entries of x above the failing row have been computed
template<typename T>
bool solve(const low_tri_mat<T>& a, const vec<T>&b, vec<T>&x) 
{
	// a must be square and b must supply one entry per row, otherwise
	// the accesses a(i,i) and b[i] below can go out of range
	assert(a.nrows() == a.ncols());
	assert(b.dim() == a.nrows());

	int N = a.nrows();
	x.resize(N);
	for(int i = 0; i < N; i++)
	{
		// zero pivot: a is singular, no need to accumulate this row
		if(a(i,i) == 0)
			return false;
		T sum = 0;
		for(int j = 0; j < i; j++)
			sum += a(i,j)*x(j);
		x[i] = (b[i] - sum)/a(i,i);
	}
	return true;
}

///solves multiple linear systems ax=b 
///a is a lower triangular matrix
///x is the matrix of solution vectors (columns)
///b is the matrix of right-hand sides (columns)
template<typename T>
bool solve(const low_tri_mat<T>& a,const mat<T>&b,mat<T>&x) 
{
	assert(b.nrows() == a.ncols());
	vec<T> xcol;
	x.resize(b.nrows(),b.ncols());
	for(unsigned i = 0; i < b.ncols();i++)
	{
		if(!solve(a,b.col(i),xcol))
			return false;
		x.set_col(i,xcol);
	}
	return true;
}

///solves the diagonal system a*x=b
///every entry of b is divided by the matching diagonal entry of a
///returns false at the first zero diagonal entry; entries of x before
///that index have already been written
template<typename T>
bool solve(const diag_mat<T>& a, const vec<T>&b, vec<T>&x) 
{
	const int n = a.ncols();
	x.resize(n);
	for(int k = 0; k < n; ++k)
	{
		if(!(a(k) != 0))
			return false;
		x(k) = static_cast<T>(b(k)) / a(k);
	}
	return true;
}

///solves linear system ax=b
///a is a tri diagonal matrix
///uses the Thomas algorithm: forward elimination followed by back
///substitution, without pivoting. It therefore returns false as soon as a
///zero pivot appears, which can happen even for a non-singular a
///(e.g. when a(0,0) is zero).
///An empty right-hand side is accepted and yields an empty x.
template <typename T>
bool solve(const tri_diag_mat<T>& a, const vec<T>& b, vec<T>& x)
{
	int n = b.dim();
	x.resize(n);
	// an empty system is trivially solved; without this guard bb(0) and
	// x(n - 1) below would be accessed out of range
	if(n == 0)
		return true;

	// sub-, main- and super-diagonal; the elimination below pairs aa(i) with
	// row i (presumably aa(i) = a(i,i-1) and cc(i) = a(i,i+1))
	vec<T> aa = a.band(-1);
	vec<T> bb = a.band(0);
	vec<T> cc = a.band(1);
	// working copy of the right-hand side, overwritten during elimination
	vec<T> dd = b;

	// forward sweep: normalize row 0, then eliminate the sub-diagonal row by row
	if(bb(0) == 0)
		return false;
	cc(0) /= bb(0);
	dd(0) /= bb(0);
	for(int i = 1; i < n; i++)
	{
		T id = (bb(i) - cc(i-1) * aa(i));
		if(id == 0)
			return false;
		cc(i) /= id;
		dd(i) = (dd(i) - dd(i-1) * aa(i))/id;
	}

	// back substitution
	x(n - 1) = dd(n - 1);
	for(int i = n - 2; i >= 0; i--)
		x(i) = dd(i) - cc(i) * x(i + 1);
	return true;
}

///solves multiple linear systems ax=b 
///a is a diagonal matrix
///x is the matrix of solution vectors (columns)
///b is the matrix of right-hand sides (columns)
template<typename T>
bool solve(const diag_mat<T>& a, const mat<T>&b, mat<T>&x) 
{
	assert(b.nrows() == a.ncols());
	vec<T> xcol;
	x.resize(b.nrows(),b.ncols());
	for(unsigned i = 0; i < b.ncols();i++)
	{
		if(!solve(a,b.col(i),xcol))
			return false;
		x.set_col(i,xcol);
	}
	return true;
}

///solves a*x=b where a is a permutation matrix
///the inverse of a permutation matrix is its transpose, so the solution is
///simply x = transpose(a)*b; this overload always returns true
template<typename T>
bool solve(const perm_mat &a, const vec<T> &b, vec<T>&x)
{
	// x is sized before the product is formed
	x.resize(a.nrows());
	x = transpose(a) * b;
	return true;
}

///solves multiple linear systems ax=b 
///a is permutation matrix
///x is the matrix of solution vectors (columns)
///b is the matrix of right-hand sides (columns)
template<typename T>
bool solve(const perm_mat& a,const mat<T>&b,mat<T>&x) 
{
	assert(a.nrows() == a.ncols());
	vec<T> xcol;
	x.resize(b.nrows(),b.ncols());
	for(unsigned i = 0; i < b.ncols();i++)
	{
		if(!solve(a,b.col(i),xcol))
			return false;
		x.set_col(i,xcol);
	}
	return true;
}




///solve ax=b with lu decomposition
///a is a full storage matrix
template<typename T>
bool lu_solve(const mat<T>& a, const vec<T>&b, vec<T>&x) 
{
	assert(a.nrows() == a.ncols());
	x.resize(a.nrows());
	vec<T> temp1,temp2;
	low_tri_mat<T> L;
	up_tri_mat<T> U;
	perm_mat P;
	if(!lu(a,P,L,U))
		return false;
	if(!solve(P,b,temp1))
		return false;
	if(!solve(L,temp1,temp2))
		return false;
	return solve(U,temp2,x);		
}

///solves a*x=b for a full storage matrix a
///for a single right-hand side the default method is the LU based lu_solve
template<typename T>
bool solve(const mat<T>& a, const vec<T>&b, vec<T>&x) 
{
	const bool ok = lu_solve(a, b, x);
	return ok;
}

///solves multiple linear systems ax=b 
///a is full storage matrix
///x is the matrix of solution vectors (columns)
///b is the matrix of right-hand sides (columns)
template<typename T>
bool lu_solve(const mat<T>& a,const mat<T>&b,mat<T>&x) 
{
	assert(a.nrows() == a.ncols());
	x.resize(b.nrows(),b.ncols());
	mat<T> temp1,temp2;
	low_tri_mat<T> L;
	up_tri_mat<T> U;
	perm_mat P;
	if(!lu(a,P,L,U))
		return false;
	if(!solve(P,b,temp1))
		return false;
	if(!solve(L,temp1,temp2))
		return false;
	return solve(U,temp2,x);
	
}

///solves several systems a*x=b at once for a full storage matrix a
///each column of b is a right-hand side, the matching column of x its solution
///note: unlike the single right-hand side overload (LU based), this overload
///delegates to svd_solve
template<typename T>
bool solve(const mat<T>& a, const mat<T>&b, mat<T>&x) 
{
	const bool ok = svd_solve(a, b, x);
	return ok;
}



///solve ax=b with qr decomposition
template<typename T>
bool qr_solve(const mat<T>& a, const vec<T>&b, vec<T>&x) 
{
	x.resize(a.nrows());
	vec<T> temp;
	mat<T> q;
	up_tri_mat<T> r;
	if(!qr(a,q,r))
		return false;
	Atx(q,b,temp);
	return solve(r,temp,x);
}


///solves multiple linear systems  ax=b with qr solver 
///a is full storage matrix
///x is the matrix of solution vectors (columns)
///b is the matrix of right-hand sides (columns)
template<typename T>
bool qr_solve(const mat<T>& a, const mat<T>&b, mat<T>&x) 
{
	assert(a.nrows() == a.ncols());
	x.resize(b.nrows(),b.ncols());
	mat<T> temp;
	mat<T> q;
	up_tri_mat<T> r;
	if(!qr(a,q,r))
		return false;
	AtB(q,b,temp);
	return solve(r,temp,x);
}



///solve ax=b with svd decomposition
template<typename T>
bool svd_solve(const mat<T>& a, const vec<T>&b, vec<T>&x) 
{
	x.resize(a.nrows());
	vec<T> temp;
	mat<T> u,v;
	diag_mat<T> d;
	if(!svd(a,u,d,v))
		return false;
	  
	Atx(u,b,temp);
	if(!solve(d,temp,x))
		return false;
	x=v*x;
	return true;
}

///solve ax=b with svd decomposition
template<typename T>
bool svd_solve(const mat<T>& a, const mat<T>&b, mat<T>&x) 
{
	assert(a.nrows() == a.ncols());
	x.resize(b.nrows(),b.ncols());
	
	mat<T> temp;
	mat<T> u,v;
	diag_mat<T> d;
	if(!svd(a,u,d,v))
		return false;  
	AtB(u,b,temp);
	if(!solve(d,temp,x))
		return false;
	x=v*x;
	return true;
}






}


}