// CGII/framework/include/cgv/math/lin_solve.h
// 2018-05-17 16:01:02 +02:00
//
// 343 lines
// 6.7 KiB
// C++

#pragma once
#include <cgv/math/vec.h>
#include <cgv/math/mat.h>
#include <cgv/math/perm_mat.h>
#include <cgv/math/diag_mat.h>
#include <cgv/math/tri_diag_mat.h>
#include <cgv/math/up_tri_mat.h>
#include <cgv/math/low_tri_mat.h>
#include <cgv/math/lu.h>
#include <cgv/math/svd.h>
#include <cgv/math/qr.h>
namespace cgv {
namespace math {
///solves linear system ax=b by back substitution
///a is a square upper triangular matrix
///b must provide at least one entry per row of a
///returns false if a has a zero on its diagonal; x is then only partially computed
template<typename T>
bool solve(const up_tri_mat<T>& a,const vec<T>&b, vec<T>&x)
{
	assert(a.nrows() == a.ncols());
	// b[i] is read for every row i < nrows(); a shorter b would be read out of range
	assert(b.dim() >= a.nrows());
	int N = a.nrows();
	x.resize(N);
	// rows are processed bottom-up: row i only needs x(j) for j > i, which are already solved
	for(int i = N-1; i >= 0; i--)
	{
		if(a(i,i) == 0)
			return false;
		T sum = 0;
		for(int j = i+1; j < N; j++)
			sum += a(i,j)*x(j);
		x[i] = (b[i] - sum)/a(i,i);
	}
	return true;
}
///solves multiple linear systems ax=b
///a is an upper triangular matrix
///x is the matrix of solution vectors (columns)
///b is the matrix of right-hand sides (columns)
template<typename T>
bool solve(const up_tri_mat<T>& a,const mat<T>&b,mat<T>&x)
{
assert(b.nrows() == a.ncols());
vec<T> xcol;
x.resize(b.nrows(),b.ncols());
for(unsigned i = 0; i < b.ncols();i++)
{
if(!solve(a,b.col(i),xcol))
return false;
x.set_col(i,xcol);
}
return true;
}
///solves linear system ax=b by forward substitution
///a is a square lower triangular matrix
///b must provide at least one entry per row of a
///returns false if a has a zero on its diagonal; x is then only partially computed
template<typename T>
bool solve(const low_tri_mat<T>& a, const vec<T>&b, vec<T>&x)
{
	// same preconditions as the upper triangular overload
	assert(a.nrows() == a.ncols());
	assert(b.dim() >= a.nrows());
	int N = a.nrows();
	x.resize(N);
	// rows are processed top-down: row i only needs x(j) for j < i, which are already solved
	for(int i = 0; i < N; i++)
	{
		if(a(i,i) == 0)
			return false;
		T sum = 0;
		for(int j = 0; j < i; j++)
			sum += a(i,j)*x(j);
		x[i] = (b[i] - sum)/a(i,i);
	}
	return true;
}
///solves multiple linear systems ax=b
///a is a lower triangular matrix
///x is the matrix of solution vectors (columns)
///b is the matrix of right-hand sides (columns)
template<typename T>
bool solve(const low_tri_mat<T>& a,const mat<T>&b,mat<T>&x)
{
assert(b.nrows() == a.ncols());
vec<T> xcol;
x.resize(b.nrows(),b.ncols());
for(unsigned i = 0; i < b.ncols();i++)
{
if(!solve(a,b.col(i),xcol))
return false;
x.set_col(i,xcol);
}
return true;
}
///solves linear system ax=b for a diagonal matrix a
///the unknowns are decoupled, so each one is x(i) = b(i)/a(i)
///returns false on a zero diagonal entry; only the entries before it have been written then
template<typename T>
bool solve(const diag_mat<T>& a, const vec<T>&b, vec<T>&x)
{
	const int n = a.ncols();
	x.resize(n);
	for(int k = 0; k < n; ++k)
	{
		const T d = a(k);
		if(d == 0)
			return false;
		x(k) = static_cast<T>(b(k))/d;
	}
	return true;
}
///solves linear system ax=b with the Thomas algorithm (tridiagonal elimination without pivoting)
///a is a tri diagonal matrix whose size matches b
///returns false if a zero pivot is encountered; an empty system yields an empty x and true
template <typename T>
bool solve(const tri_diag_mat<T>& a, const vec<T>& b, vec<T>& x)
{
	int n = b.dim();
	x.resize(n);
	// an empty system has the empty solution; without this guard bb(0) and x(n-1) are out of range
	if(n == 0)
		return true;
	// NOTE(review): assumes band(k) returns a length-n vector indexed by row, i.e.
	// aa(i) = a(i,i-1) and cc(i) = a(i,i+1) with aa(0) and cc(n-1) unused -- confirm against tri_diag_mat
	vec<T> aa = a.band(-1);
	vec<T> bb = a.band(0);
	vec<T> cc = a.band(1);
	vec<T> dd = b;
	// the main diagonal of a must have exactly one entry per unknown
	assert(bb.dim() == b.dim());
	if(bb(0) == 0)
		return false;
	// forward sweep: eliminate the sub diagonal and normalize every row by its pivot
	// (the last super diagonal entry is never read by the back substitution, so it is skipped)
	if(n > 1)
		cc(0) /= bb(0);
	dd(0) /= bb(0);
	for(int i = 1; i < n; i++)
	{
		T pivot = bb(i) - cc(i-1) * aa(i);
		if(pivot == 0)
			return false;
		if(i + 1 < n)
			cc(i) /= pivot;
		dd(i) = (dd(i) - dd(i-1) * aa(i))/pivot;
	}
	// back substitution
	x(n - 1) = dd(n - 1);
	for(int i = n - 2; i >= 0; i--)
		x(i) = dd(i) - cc(i) * x(i + 1);
	return true;
}
///solves multiple linear systems ax=b
///a is a diagonal matrix
///x is the matrix of solution vectors (columns)
///b is the matrix of right-hand sides (columns)
template<typename T>
bool solve(const diag_mat<T>& a, const mat<T>&b, mat<T>&x)
{
assert(b.nrows() == a.ncols());
vec<T> xcol;
x.resize(b.nrows(),b.ncols());
for(unsigned i = 0; i < b.ncols();i++)
{
if(!solve(a,b.col(i),xcol))
return false;
x.set_col(i,xcol);
}
return true;
}
///solves linear system ax=b for a permutation matrix a
///the inverse of a permutation matrix is its transpose, so x = transpose(a)*b
///this solve cannot fail and always returns true
template<typename T>
bool solve(const perm_mat &a, const vec<T> &b, vec<T>&x)
{
	x.resize(a.nrows());
	// undo the permutation applied by a
	x = transpose(a) * b;
	return true;
}
///solves multiple linear systems ax=b
///a is a permutation matrix
///x is the matrix of solution vectors (columns)
///b is the matrix of right-hand sides (columns)
template<typename T>
bool solve(const perm_mat& a,const mat<T>&b,mat<T>&x)
{
	assert(a.nrows() == a.ncols());
	// like the other multi-rhs overloads, every right-hand side needs one entry per column of a
	assert(b.nrows() == a.ncols());
	vec<T> xcol;
	x.resize(b.nrows(),b.ncols());
	for(unsigned i = 0; i < b.ncols(); i++)
	{
		if(!solve(a,b.col(i),xcol))
			return false;
		x.set_col(i,xcol);
	}
	return true;
}
///solve ax=b with lu decomposition
///a is a full storage matrix
template<typename T>
bool lu_solve(const mat<T>& a, const vec<T>&b, vec<T>&x)
{
assert(a.nrows() == a.ncols());
x.resize(a.nrows());
vec<T> temp1,temp2;
low_tri_mat<T> L;
up_tri_mat<T> U;
perm_mat P;
if(!lu(a,P,L,U))
return false;
if(!solve(P,b,temp1))
return false;
if(!solve(L,temp1,temp2))
return false;
return solve(U,temp2,x);
}
///default solver for ax=b with a full storage matrix a
///delegates to lu_solve and forwards its success flag
template<typename T>
bool solve(const mat<T>& a, const vec<T>&b, vec<T>&x)
{
	return lu_solve(a, b, x);
}
///solves multiple linear systems ax=b
///a is full storage matrix
///x is the matrix of solution vectors (columns)
///b is the matrix of right-hand sides (columns)
template<typename T>
bool lu_solve(const mat<T>& a,const mat<T>&b,mat<T>&x)
{
assert(a.nrows() == a.ncols());
x.resize(b.nrows(),b.ncols());
mat<T> temp1,temp2;
low_tri_mat<T> L;
up_tri_mat<T> U;
perm_mat P;
if(!lu(a,P,L,U))
return false;
if(!solve(P,b,temp1))
return false;
if(!solve(L,temp1,temp2))
return false;
return solve(U,temp2,x);
}
///solves multiple linear systems ax=b for a full storage matrix a
///b holds the right-hand sides and x receives the solutions, both column-wise
///note: unlike the single right-hand side overload (which uses lu_solve), this one delegates to svd_solve
template<typename T>
bool solve(const mat<T>& a, const mat<T>&b, mat<T>&x)
{
	return svd_solve(a, b, x);
}
///solve ax=b with qr decomposition
template<typename T>
bool qr_solve(const mat<T>& a, const vec<T>&b, vec<T>&x)
{
x.resize(a.nrows());
vec<T> temp;
mat<T> q;
up_tri_mat<T> r;
if(!qr(a,q,r))
return false;
Atx(q,b,temp);
return solve(r,temp,x);
}
///solves multiple linear systems ax=b with qr solver
///a is full storage matrix
///x is the matrix of solution vectors (columns)
///b is the matrix of right-hand sides (columns)
template<typename T>
bool qr_solve(const mat<T>& a, const mat<T>&b, mat<T>&x)
{
assert(a.nrows() == a.ncols());
x.resize(b.nrows(),b.ncols());
mat<T> temp;
mat<T> q;
up_tri_mat<T> r;
if(!qr(a,q,r))
return false;
AtB(q,b,temp);
return solve(r,temp,x);
}
///solve ax=b with svd decomposition
template<typename T>
bool svd_solve(const mat<T>& a, const vec<T>&b, vec<T>&x)
{
x.resize(a.nrows());
vec<T> temp;
mat<T> u,v;
diag_mat<T> d;
if(!svd(a,u,d,v))
return false;
Atx(u,b,temp);
if(!solve(d,temp,x))
return false;
x=v*x;
return true;
}
///solve ax=b with svd decomposition
template<typename T>
bool svd_solve(const mat<T>& a, const mat<T>&b, mat<T>&x)
{
assert(a.nrows() == a.ncols());
x.resize(b.nrows(),b.ncols());
mat<T> temp;
mat<T> u,v;
diag_mat<T> d;
if(!svd(a,u,d,v))
return false;
AtB(u,b,temp);
if(!solve(d,temp,x))
return false;
x=v*x;
return true;
}
}
}