Prev Next lu_solve.hpp Headings

Source: LuSolve
# ifndef CPPAD_LU_SOLVE_INCLUDED
# define CPPAD_LU_SOLVE_INCLUDED
# include <complex>
# include <vector>

// link exp for float and double cases
# include <cppad/std_math_unary.hpp>

# include <cppad/local/cppad_assert.hpp>
# include <cppad/check_simple_vector.hpp>
# include <cppad/check_numeric_type.hpp>
# include <cppad/lu_factor.hpp>
# include <cppad/lu_invert.hpp>

namespace CppAD { // BEGIN CppAD namespace

// LeqZero
template <typename Float>
inline bool LeqZero(const Float &x)
{    return x <= Float(0); }
inline bool LeqZero( const std::complex<double> &x )
{    return x == std::complex<double>(0); }
inline bool LeqZero( const std::complex<float> &x )
{    return x == std::complex<float>(0); }

// LuSolve
template <typename Float, typename FloatVector>
int LuSolve(
     size_t             n      ,
     size_t             m      , 
     const FloatVector &A      , 
     const FloatVector &B      , 
     FloatVector       &X      , 
     Float        &logdet      )
{    
     // check numeric type specifications
     CheckNumericType<Float>();

     // check simple vector class specifications
     CheckSimpleVector<Float, FloatVector>();

     size_t        p;       // index of pivot element (diagonal of L)
     int     signdet;       // sign of the determinant
     Float     pivot;       // pivot element

     // the value zero
     const Float zero(0);

     // pivot row and column order in the matrix
     std::vector<size_t> ip(n);
     std::vector<size_t> jp(n);

     // -------------------------------------------------------
     CPPAD_ASSERT_KNOWN(
          A.size() == n * n,
          "Error in LuSolve: A must have size equal to n * n"
     );
     CPPAD_ASSERT_KNOWN(
          B.size() == n * m,
          "Error in LuSolve: B must have size equal to n * m"
     );
     CPPAD_ASSERT_KNOWN(
          X.size() == n * m,
          "Error in LuSolve: X must have size equal to n * m"
     );
     // -------------------------------------------------------

     // copy A so that it does not change
     FloatVector Lu(A);

     // copy B so that it does not change
     X = B;

     // Lu factor the matrix A
     signdet = LuFactor(ip, jp, Lu);

     // compute the log of the determinant
     logdet  = Float(0);
     for(p = 0; p < n; p++)
     {    // pivot using the max absolute element
          pivot   = Lu[ ip[p] * n + jp[p] ];

          // check for determinant equal to zero
          if( pivot == zero )
          {    // abort the mission
               logdet = Float(0);
               return   0;
          }

          // update the determinant
          if( LeqZero ( pivot ) )
          {    logdet += log( - pivot );
               signdet = - signdet;
          }
          else logdet += log( pivot );

     }

     // solve the linear equations
     LuInvert(ip, jp, Lu, X);

     // return the sign factor for the determinant
     return signdet;
}
} // END CppAD namespace 
# endif

Input File: omh/lu_solve_hpp.omh