$treeview $search $mathjax
Eigen
3.2.5
$projectbrief
|
$projectbrief
|
$searchbox |
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2011 Gael Guennebaud <gael.guennebaud@inria.fr> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_CONJUGATE_GRADIENT_H 00011 #define EIGEN_CONJUGATE_GRADIENT_H 00012 00013 namespace Eigen { 00014 00015 namespace internal { 00016 00026 template<typename MatrixType, typename Rhs, typename Dest, typename Preconditioner> 00027 EIGEN_DONT_INLINE 00028 void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x, 00029 const Preconditioner& precond, int& iters, 00030 typename Dest::RealScalar& tol_error) 00031 { 00032 using std::sqrt; 00033 using std::abs; 00034 typedef typename Dest::RealScalar RealScalar; 00035 typedef typename Dest::Scalar Scalar; 00036 typedef Matrix<Scalar,Dynamic,1> VectorType; 00037 00038 RealScalar tol = tol_error; 00039 int maxIters = iters; 00040 00041 int n = mat.cols(); 00042 00043 VectorType residual = rhs - mat * x; //initial residual 00044 00045 RealScalar rhsNorm2 = rhs.squaredNorm(); 00046 if(rhsNorm2 == 0) 00047 { 00048 x.setZero(); 00049 iters = 0; 00050 tol_error = 0; 00051 return; 00052 } 00053 RealScalar threshold = tol*tol*rhsNorm2; 00054 RealScalar residualNorm2 = residual.squaredNorm(); 00055 if (residualNorm2 < threshold) 00056 { 00057 iters = 0; 00058 tol_error = sqrt(residualNorm2 / rhsNorm2); 00059 return; 00060 } 00061 00062 VectorType p(n); 00063 p = precond.solve(residual); //initial search direction 00064 00065 VectorType z(n), tmp(n); 00066 RealScalar absNew = numext::real(residual.dot(p)); // the square of the absolute value of r scaled by invM 00067 int i = 0; 00068 while(i < maxIters) 00069 { 00070 tmp.noalias() = mat * p; // the bottleneck of the algorithm 00071 00072 Scalar alpha = absNew / p.dot(tmp); // the amount we travel on dir 00073 x += alpha * p; // update solution 00074 residual -= alpha * tmp; // update residue 00075 00076 residualNorm2 = residual.squaredNorm(); 00077 if(residualNorm2 < threshold) 00078 break; 00079 00080 z = precond.solve(residual); // approximately solve for "A z = residual" 00081 00082 RealScalar absOld = absNew; 00083 absNew = numext::real(residual.dot(z)); // update the absolute value of r 00084 RealScalar beta = absNew / absOld; // calculate the Gram-Schmidt value used to create the new search direction 00085 p = z + beta * p; // update search direction 00086 i++; 00087 } 00088 tol_error = sqrt(residualNorm2 / rhsNorm2); 00089 iters = i; 00090 } 00091 00092 } 00093 00094 template< typename _MatrixType, int _UpLo=Lower, 00095 typename _Preconditioner = DiagonalPreconditioner<typename _MatrixType::Scalar> > 00096 class ConjugateGradient; 00097 00098 namespace internal { 00099 00100 template< typename _MatrixType, int _UpLo, typename _Preconditioner> 00101 struct traits<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> > 00102 { 00103 typedef _MatrixType MatrixType; 00104 typedef _Preconditioner Preconditioner; 00105 }; 00106 00107 } 00108 00144 template< typename _MatrixType, int _UpLo, typename _Preconditioner> 00145 class ConjugateGradient : public IterativeSolverBase<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> > 00146 { 00147 typedef IterativeSolverBase<ConjugateGradient> Base; 00148 using Base::mp_matrix; 00149 using Base::m_error; 00150 using Base::m_iterations; 00151 using Base::m_info; 00152 using Base::m_isInitialized; 00153 public: 00154 typedef _MatrixType MatrixType; 00155 typedef typename MatrixType::Scalar Scalar; 00156 typedef typename MatrixType::Index Index; 00157 typedef typename MatrixType::RealScalar RealScalar; 00158 typedef _Preconditioner Preconditioner; 00159 00160 enum { 00161 UpLo = _UpLo 00162 }; 00163 00164 public: 00165 00167 ConjugateGradient() : Base() {} 00168 00179 ConjugateGradient(const MatrixType& A) : Base(A) {} 00180 00181 ~ConjugateGradient() {} 00182 00188 template<typename Rhs,typename Guess> 00189 inline const internal::solve_retval_with_guess<ConjugateGradient, Rhs, Guess> 00190 solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const 00191 { 00192 eigen_assert(m_isInitialized && "ConjugateGradient is not initialized."); 00193 eigen_assert(Base::rows()==b.rows() 00194 && "ConjugateGradient::solve(): invalid number of rows of the right hand side matrix b"); 00195 return internal::solve_retval_with_guess 00196 <ConjugateGradient, Rhs, Guess>(*this, b.derived(), x0); 00197 } 00198 00200 template<typename Rhs,typename Dest> 00201 void _solveWithGuess(const Rhs& b, Dest& x) const 00202 { 00203 typedef typename internal::conditional<UpLo==(Lower|Upper), 00204 const MatrixType&, 00205 SparseSelfAdjointView<const MatrixType, UpLo> 00206 >::type MatrixWrapperType; 00207 m_iterations = Base::maxIterations(); 00208 m_error = Base::m_tolerance; 00209 00210 for(int j=0; j<b.cols(); ++j) 00211 { 00212 m_iterations = Base::maxIterations(); 00213 m_error = Base::m_tolerance; 00214 00215 typename Dest::ColXpr xj(x,j); 00216 internal::conjugate_gradient(MatrixWrapperType(*mp_matrix), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error); 00217 } 00218 00219 m_isInitialized = true; 00220 m_info = m_error <= Base::m_tolerance ? Success : NoConvergence; 00221 } 00222 00224 template<typename Rhs,typename Dest> 00225 void _solve(const Rhs& b, Dest& x) const 00226 { 00227 x.setZero(); 00228 _solveWithGuess(b,x); 00229 } 00230 00231 protected: 00232 00233 }; 00234 00235 00236 namespace internal { 00237 00238 template<typename _MatrixType, int _UpLo, typename _Preconditioner, typename Rhs> 00239 struct solve_retval<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs> 00240 : solve_retval_base<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs> 00241 { 00242 typedef ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> Dec; 00243 EIGEN_MAKE_SOLVE_HELPERS(Dec,Rhs) 00244 00245 template<typename Dest> void evalTo(Dest& dst) const 00246 { 00247 dec()._solve(rhs(),dst); 00248 } 00249 }; 00250 00251 } // end namespace internal 00252 00253 } // end namespace Eigen 00254 00255 #endif // EIGEN_CONJUGATE_GRADIENT_H