#pragma once #include #include #include #include #include #include #include #include #include #include namespace cgv { namespace math { ///solves linear system ax=b ///a is an upper triangular matrix template bool solve(const up_tri_mat& a,const vec&b, vec&x) { assert(a.nrows() == a.ncols()); int N = a.nrows(); x.resize(N); T sum; for(int i = N-1; i >= 0;i--) { sum =0; for(int j = i+1;j < N; j++) sum += a(i,j)*x(j); if(a(i,i) == 0) return false; x[i] = (b[i] - sum)/a(i,i); } return true; } ///solves multiple linear systems ax=b ///a is an upper triangular matrix ///x is the matrix of solution vectors (columns) ///b is the matrix of right-hand sides (columns) template bool solve(const up_tri_mat& a,const mat&b,mat&x) { assert(b.nrows() == a.ncols()); vec xcol; x.resize(b.nrows(),b.ncols()); for(unsigned i = 0; i < b.ncols();i++) { if(!solve(a,b.col(i),xcol)) return false; x.set_col(i,xcol); } return true; } ///solves linear system ax=b ///a is a lower triangular matrix template bool solve(const low_tri_mat& a, const vec&b, vec&x) { int N = a.nrows(); x.resize(N); T sum; for(int i = 0; i < N;i++) { sum =0; for(int j = 0;j < i; j++) sum += a(i,j)*x(j); if(a(i,i) == 0) return false; x[i] = (b[i] - sum)/a(i,i); } return true; } ///solves multiple linear systems ax=b ///a is a lower triangular matrix ///x is the matrix of solution vectors (columns) ///b is the matrix of right-hand sides (columns) template bool solve(const low_tri_mat& a,const mat&b,mat&x) { assert(b.nrows() == a.ncols()); vec xcol; x.resize(b.nrows(),b.ncols()); for(unsigned i = 0; i < b.ncols();i++) { if(!solve(a,b.col(i),xcol)) return false; x.set_col(i,xcol); } return true; } ///solves linear system ax=b ///a is a diagonal matrix template bool solve(const diag_mat& a, const vec&b, vec&x) { int N = a.ncols(); x.resize(N); for(int i = 0; i < N;i++) { if(a(i) == 0) return false; x(i) = (T)b(i)/a(i); } return true; } ///solves linear system ax=b ///a is a tri diagonal matrix template bool solve(const tri_diag_mat& a, const vec& b, vec& x) { x.resize(b.dim()); int i; vec aa = a.band(-1); vec bb = a.band(0); vec cc = a.band(1); vec dd = b; int n = b.dim(); if(bb(0) == 0) return false; cc(0) /= bb(0); dd(0) /= bb(0); for(i = 1; i < n; i++) { T id = (bb(i) - cc(i-1) * aa(i)); if(id == 0) return false; cc(i) /= id; dd(i) = (dd(i) - dd(i-1) * aa(i))/id; } x(n - 1) = dd(n - 1); for(i = n - 2; i >= 0; i--) x(i) = dd(i) - cc(i) * x(i + 1); return true; } ///solves multiple linear systems ax=b ///a is a diagonal matrix ///x is the matrix of solution vectors (columns) ///b is the matrix of right-hand sides (columns) template bool solve(const diag_mat& a, const mat&b, mat&x) { assert(b.nrows() == a.ncols()); vec xcol; x.resize(b.nrows(),b.ncols()); for(unsigned i = 0; i < b.ncols();i++) { if(!solve(a,b.col(i),xcol)) return false; x.set_col(i,xcol); } return true; } ///solves linear system ax=b ///a is a permutation matrix template bool solve(const perm_mat &a, const vec &b, vec&x) { x.resize(a.nrows()); x=transpose(a)*b; return true; } ///solves multiple linear systems ax=b ///a is permutation matrix ///x is the matrix of solution vectors (columns) ///b is the matrix of right-hand sides (columns) template bool solve(const perm_mat& a,const mat&b,mat&x) { assert(a.nrows() == a.ncols()); vec xcol; x.resize(b.nrows(),b.ncols()); for(unsigned i = 0; i < b.ncols();i++) { if(!solve(a,b.col(i),xcol)) return false; x.set_col(i,xcol); } return true; } ///solve ax=b with lu decomposition ///a is a full storage matrix template bool lu_solve(const mat& a, const vec&b, vec&x) { assert(a.nrows() == a.ncols()); x.resize(a.nrows()); vec temp1,temp2; low_tri_mat L; up_tri_mat U; perm_mat P; if(!lu(a,P,L,U)) return false; if(!solve(P,b,temp1)) return false; if(!solve(L,temp1,temp2)) return false; return solve(U,temp2,x); } ///solve ax=b, standard solver for full storage matrix is lu_solve ///a is a full storage matrix template bool solve(const mat& a, const vec&b, vec&x) { return lu_solve( a, b, x) ; } ///solves multiple linear systems ax=b ///a is full storage matrix ///x is the matrix of solution vectors (columns) ///b is the matrix of right-hand sides (columns) template bool lu_solve(const mat& a,const mat&b,mat&x) { assert(a.nrows() == a.ncols()); x.resize(b.nrows(),b.ncols()); mat temp1,temp2; low_tri_mat L; up_tri_mat U; perm_mat P; if(!lu(a,P,L,U)) return false; if(!solve(P,b,temp1)) return false; if(!solve(L,temp1,temp2)) return false; return solve(U,temp2,x); } ///solves multiple linear systems ax=b with the svd solver ///a is full storage matrix ///x is the matrix of solution vectors (columns) ///b is the matrix of right-hand sides (columns) template bool solve(const mat& a, const mat&b, mat&x) { return svd_solve( a, b, x) ; } ///solve ax=b with qr decomposition template bool qr_solve(const mat& a, const vec&b, vec&x) { x.resize(a.nrows()); vec temp; mat q; up_tri_mat r; if(!qr(a,q,r)) return false; Atx(q,b,temp); return solve(r,temp,x); } ///solves multiple linear systems ax=b with qr solver ///a is full storage matrix ///x is the matrix of solution vectors (columns) ///b is the matrix of right-hand sides (columns) template bool qr_solve(const mat& a, const mat&b, mat&x) { assert(a.nrows() == a.ncols()); x.resize(b.nrows(),b.ncols()); mat temp; mat q; up_tri_mat r; if(!qr(a,q,r)) return false; AtB(q,b,temp); return solve(r,temp,x); } ///solve ax=b with svd decomposition template bool svd_solve(const mat& a, const vec&b, vec&x) { x.resize(a.nrows()); vec temp; mat u,v; diag_mat d; if(!svd(a,u,d,v)) return false; Atx(u,b,temp); if(!solve(d,temp,x)) return false; x=v*x; return true; } ///solve ax=b with svd decomposition template bool svd_solve(const mat& a, const mat&b, mat&x) { assert(a.nrows() == a.ncols()); x.resize(b.nrows(),b.ncols()); mat temp; mat u,v; diag_mat d; if(!svd(a,u,d,v)) return false; AtB(u,b,temp); if(!solve(d,temp,x)) return false; x=v*x; return true; } } }