IT++ Logo

ls_solve.cpp

Go to the documentation of this file.
00001 
00030 #ifndef _MSC_VER
00031 #  include <itpp/config.h>
00032 #else
00033 #  include <itpp/config_msvc.h>
00034 #endif
00035 
00036 #if defined(HAVE_LAPACK)
00037 #  include <itpp/base/algebra/lapack.h>
00038 #endif
00039 
00040 #include <itpp/base/algebra/ls_solve.h>
00041 
00042 
00043 namespace itpp {
00044 
00045   // ----------- ls_solve_chol -----------------------------------------------------------
00046 
00047 #if defined(HAVE_LAPACK)
00048 
00049   bool ls_solve_chol(const mat &A, const vec &b, vec &x)
00050   {
00051     int n, lda, ldb, nrhs, info;
00052     n = lda = ldb = A.rows();
00053     nrhs = 1;
00054     char uplo='U';
00055 
00056     it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00057     it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
00058 
00059     ivec ipiv(n);
00060     x = b;
00061     mat Chol = A;
00062 
00063     dposv_(&uplo, &n, &nrhs, Chol._data(), &lda, x._data(), &ldb, &info);
00064 
00065     return (info==0);
00066   }
00067 
00068 
00069   bool ls_solve_chol(const mat &A, const mat &B, mat &X)
00070   {
00071     int n, lda, ldb, nrhs, info;
00072     n = lda = ldb = A.rows();
00073     nrhs = B.cols();
00074     char uplo='U';
00075 
00076     it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00077     it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
00078 
00079     ivec ipiv(n);
00080     X = B;
00081     mat Chol = A;
00082 
00083     dposv_(&uplo, &n, &nrhs, Chol._data(), &lda, X._data(), &ldb, &info);
00084 
00085     return (info==0);
00086   }
00087 
00088   bool ls_solve_chol(const cmat &A, const cvec &b, cvec &x)
00089   {
00090     int n, lda, ldb, nrhs, info;
00091     n = lda = ldb = A.rows();
00092     nrhs = 1;
00093     char uplo='U';
00094 
00095     it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00096     it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
00097 
00098     ivec ipiv(n);
00099     x = b;
00100     cmat Chol = A;
00101 
00102     zposv_(&uplo, &n, &nrhs, Chol._data(), &lda, x._data(), &ldb, &info);
00103 
00104     return (info==0);
00105   }
00106 
00107   bool ls_solve_chol(const cmat &A, const cmat &B, cmat &X)
00108   {
00109     int n, lda, ldb, nrhs, info;
00110     n = lda = ldb = A.rows();
00111     nrhs = B.cols();
00112     char uplo='U';
00113 
00114     it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00115     it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
00116 
00117     ivec ipiv(n);
00118     X = B;
00119     cmat Chol = A;
00120 
00121     zposv_(&uplo, &n, &nrhs, Chol._data(), &lda, X._data(), &ldb, &info);
00122 
00123     return (info==0);
00124   }
00125 
00126 #else
00127 
00128   bool ls_solve_chol(const mat &A, const vec &b, vec &x)
00129   {
00130     it_error("LAPACK library is needed to use ls_solve_chol() function");
00131     return false;
00132   }
00133 
00134   bool ls_solve_chol(const mat &A, const mat &B, mat &X)
00135   {
00136     it_error("LAPACK library is needed to use ls_solve_chol() function");
00137     return false;
00138   }
00139 
00140   bool ls_solve_chol(const cmat &A, const cvec &b, cvec &x)
00141   {
00142     it_error("LAPACK library is needed to use ls_solve_chol() function");
00143     return false;
00144   }
00145 
00146   bool ls_solve_chol(const cmat &A, const cmat &B, cmat &X)
00147   {
00148     it_error("LAPACK library is needed to use ls_solve_chol() function");
00149     return false;
00150   }
00151 
00152 #endif // HAVE_LAPACK
00153 
00154   vec ls_solve_chol(const mat &A, const vec &b)
00155   {
00156     vec x;
00157     bool info;
00158     info = ls_solve_chol(A, b, x);
00159     it_assert_debug(info, "ls_solve_chol: Failed solving the system");
00160     return x;
00161   }
00162 
00163   mat ls_solve_chol(const mat &A, const mat &B)
00164   {
00165     mat X;
00166     bool info;
00167     info = ls_solve_chol(A, B, X);
00168     it_assert_debug(info, "ls_solve_chol: Failed solving the system");
00169     return X;
00170   }
00171 
00172   cvec ls_solve_chol(const cmat &A, const cvec &b)
00173   {
00174     cvec x;
00175     bool info;
00176     info = ls_solve_chol(A, b, x);
00177     it_assert_debug(info, "ls_solve_chol: Failed solving the system");
00178     return x;
00179   }
00180 
00181   cmat ls_solve_chol(const cmat &A, const cmat &B)
00182   {
00183     cmat X;
00184     bool info;
00185     info = ls_solve_chol(A, B, X);
00186     it_assert_debug(info, "ls_solve_chol: Failed solving the system");
00187     return X;
00188   }
00189 
00190 
00191   // --------- ls_solve ---------------------------------------------------------------
00192 #if defined(HAVE_LAPACK)
00193 
00194   bool ls_solve(const mat &A, const vec &b, vec &x)
00195   {
00196     int n, lda, ldb, nrhs, info;
00197     n = lda = ldb = A.rows();
00198     nrhs = 1;
00199 
00200     it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
00201     it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
00202 
00203     ivec ipiv(n);
00204     x = b;
00205     mat LU = A;
00206 
00207     dgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), x._data(), &ldb, &info);
00208 
00209     return (info==0);
00210   }
00211 
00212   bool ls_solve(const mat &A, const mat &B, mat &X)
00213   {
00214     int n, lda, ldb, nrhs, info;
00215     n = lda = ldb = A.rows();
00216     nrhs = B.cols();
00217 
00218     it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
00219     it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
00220 
00221     ivec ipiv(n);
00222     X = B;
00223     mat LU = A;
00224 
00225     dgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), X._data(), &ldb, &info);
00226 
00227     return (info==0);
00228   }
00229 
00230   bool ls_solve(const cmat &A, const cvec &b, cvec &x)
00231   {
00232     int n, lda, ldb, nrhs, info;
00233     n = lda = ldb = A.rows();
00234     nrhs = 1;
00235 
00236     it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
00237     it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
00238 
00239     ivec ipiv(n);
00240     x = b;
00241     cmat LU = A;
00242 
00243     zgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), x._data(), &ldb, &info);
00244 
00245     return (info==0);
00246   }
00247 
00248   bool ls_solve(const cmat &A, const cmat &B, cmat &X)
00249   {
00250     int n, lda, ldb, nrhs, info;
00251     n = lda = ldb = A.rows();
00252     nrhs = B.cols();
00253 
00254     it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
00255     it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
00256 
00257     ivec ipiv(n);
00258     X = B;
00259     cmat LU = A;
00260 
00261     zgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), X._data(), &ldb, &info);
00262 
00263     return (info==0);
00264   }
00265 
00266 #else
00267 
00268   bool ls_solve(const mat &A, const vec &b, vec &x)
00269   {
00270     it_error("LAPACK library is needed to use ls_solve() function");
00271     return false;
00272   }
00273 
00274   bool ls_solve(const mat &A, const mat &B, mat &X)
00275   {
00276     it_error("LAPACK library is needed to use ls_solve() function");
00277     return false;
00278   }
00279 
00280   bool ls_solve(const cmat &A, const cvec &b, cvec &x)
00281   {
00282     it_error("LAPACK library is needed to use ls_solve() function");
00283     return false;
00284   }
00285 
00286   bool ls_solve(const cmat &A, const cmat &B, cmat &X)
00287   {
00288     it_error("LAPACK library is needed to use ls_solve() function");
00289     return false;
00290   }
00291 
00292 #endif // HAVE_LAPACK
00293 
00294   vec ls_solve(const mat &A, const vec &b)
00295   {
00296     vec x;
00297     bool info;
00298     info = ls_solve(A, b, x);
00299     it_assert_debug(info, "ls_solve: Failed solving the system");
00300     return x;
00301   }
00302 
00303   mat ls_solve(const mat &A, const mat &B)
00304   {
00305     mat X;
00306     bool info;
00307     info = ls_solve(A, B, X);
00308     it_assert_debug(info, "ls_solve: Failed solving the system");
00309     return X;
00310   }
00311 
00312   cvec ls_solve(const cmat &A, const cvec &b)
00313   {
00314     cvec x;
00315     bool info;
00316     info = ls_solve(A, b, x);
00317     it_assert_debug(info, "ls_solve: Failed solving the system");
00318     return x;
00319   }
00320 
00321   cmat ls_solve(const cmat &A, const cmat &B)
00322   {
00323     cmat X;
00324     bool info;
00325     info = ls_solve(A, B, X);
00326     it_assert_debug(info, "ls_solve: Failed solving the system");
00327     return X;
00328   }
00329 
00330 
00331   // ----------------- ls_solve_od ------------------------------------------------------------------
00332 #if defined(HAVE_LAPACK)
00333 
00334   bool ls_solve_od(const mat &A, const vec &b, vec &x)
00335   {
00336     int m, n, lda, ldb, nrhs, lwork, info;
00337     char trans='N';
00338     m = lda = ldb = A.rows();
00339     n = A.cols();
00340     nrhs = 1;
00341     lwork = n + std::max(m,nrhs);
00342 
00343     it_assert_debug(m >= n, "The system is under-determined!");
00344     it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
00345 
00346     vec work(lwork);
00347     x = b;
00348     mat QR = A;
00349 
00350     dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00351     x.set_size(n, true);
00352 
00353     return (info==0);
00354   }
00355 
00356   bool ls_solve_od(const mat &A, const mat &B, mat &X)
00357   {
00358     int m, n, lda, ldb, nrhs, lwork, info;
00359     char trans='N';
00360     m = lda = ldb = A.rows();
00361     n = A.cols();
00362     nrhs = B.cols();
00363     lwork = n + std::max(m,nrhs);
00364 
00365     it_assert_debug(m >= n, "The system is under-determined!");
00366     it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
00367 
00368     vec work(lwork);
00369     X = B;
00370     mat QR = A;
00371 
00372     dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00373     X.set_size(n, nrhs, true);
00374 
00375     return (info==0);
00376   }
00377 
00378   bool ls_solve_od(const cmat &A, const cvec &b, cvec &x)
00379   {
00380     int m, n, lda, ldb, nrhs, lwork, info;
00381     char trans='N';
00382     m = lda = ldb = A.rows();
00383     n = A.cols();
00384     nrhs = 1;
00385     lwork = n + std::max(m,nrhs);
00386 
00387     it_assert_debug(m >= n, "The system is under-determined!");
00388     it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
00389 
00390     cvec work(lwork);
00391     x = b;
00392     cmat QR = A;
00393 
00394     zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00395     x.set_size(n, true);
00396 
00397     return (info==0);
00398   }
00399 
00400   bool ls_solve_od(const cmat &A, const cmat &B, cmat &X)
00401   {
00402     int m, n, lda, ldb, nrhs, lwork, info;
00403     char trans='N';
00404     m = lda = ldb = A.rows();
00405     n = A.cols();
00406     nrhs = B.cols();
00407     lwork = n + std::max(m,nrhs);
00408 
00409     it_assert_debug(m >= n, "The system is under-determined!");
00410     it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
00411 
00412     cvec work(lwork);
00413     X = B;
00414     cmat QR = A;
00415 
00416     zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00417     X.set_size(n, nrhs, true);
00418 
00419     return (info==0);
00420   }
00421 
00422 #else
00423 
00424   bool ls_solve_od(const mat &A, const vec &b, vec &x)
00425   {
00426     it_error("LAPACK library is needed to use ls_solve_od() function");
00427     return false;
00428   }
00429 
00430   bool ls_solve_od(const mat &A, const mat &B, mat &X)
00431   {
00432     it_error("LAPACK library is needed to use ls_solve_od() function");
00433     return false;
00434   }
00435 
00436   bool ls_solve_od(const cmat &A, const cvec &b, cvec &x)
00437   {
00438     it_error("LAPACK library is needed to use ls_solve_od() function");
00439     return false;
00440   }
00441 
00442   bool ls_solve_od(const cmat &A, const cmat &B, cmat &X)
00443   {
00444     it_error("LAPACK library is needed to use ls_solve_od() function");
00445     return false;
00446   }
00447 
00448 #endif // HAVE_LAPACK
00449 
00450   vec ls_solve_od(const mat &A, const vec &b)
00451   {
00452     vec x;
00453     bool info;
00454     info = ls_solve_od(A, b, x);
00455     it_assert_debug(info, "ls_solve_od: Failed solving the system");
00456     return x;
00457   }
00458 
00459   mat ls_solve_od(const mat &A, const mat &B)
00460   {
00461     mat X;
00462     bool info;
00463     info = ls_solve_od(A, B, X);
00464     it_assert_debug(info, "ls_solve_od: Failed solving the system");
00465     return X;
00466   }
00467 
00468   cvec ls_solve_od(const cmat &A, const cvec &b)
00469   {
00470     cvec x;
00471     bool info;
00472     info = ls_solve_od(A, b, x);
00473     it_assert_debug(info, "ls_solve_od: Failed solving the system");
00474     return x;
00475   }
00476 
00477   cmat ls_solve_od(const cmat &A, const cmat &B)
00478   {
00479     cmat X;
00480     bool info;
00481     info = ls_solve_od(A, B, X);
00482     it_assert_debug(info, "ls_solve_od: Failed solving the system");
00483     return X;
00484   }
00485 
00486   // ------------------- ls_solve_ud -----------------------------------------------------------
00487 #if defined(HAVE_LAPACK)
00488 
00489   bool ls_solve_ud(const mat &A, const vec &b, vec &x)
00490   {
00491     int m, n, lda, ldb, nrhs, lwork, info;
00492     char trans='N';
00493     m = lda = A.rows();
00494     n = A.cols();
00495     ldb = n;
00496     nrhs = 1;
00497     lwork = m + std::max(n,nrhs);
00498 
00499     it_assert_debug(m < n, "The system is over-determined!");
00500     it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
00501 
00502     vec work(lwork);
00503     x = b;
00504     x.set_size(n, true);
00505     mat QR = A;
00506 
00507     dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00508 
00509     return (info==0);
00510   }
00511 
00512   bool ls_solve_ud(const mat &A, const mat &B, mat &X)
00513   {
00514     int m, n, lda, ldb, nrhs, lwork, info;
00515     char trans='N';
00516     m = lda = A.rows();
00517     n = A.cols();
00518     ldb = n;
00519     nrhs = B.cols();
00520     lwork = m + std::max(n,nrhs);
00521 
00522     it_assert_debug(m < n, "The system is over-determined!");
00523     it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
00524 
00525     vec work(lwork);
00526     X = B;
00527     X.set_size(n, std::max(m, nrhs), true);
00528     mat QR = A;
00529 
00530     dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00531     X.set_size(n, nrhs, true);
00532 
00533     return (info==0);
00534   }
00535 
00536   bool ls_solve_ud(const cmat &A, const cvec &b, cvec &x)
00537   {
00538     int m, n, lda, ldb, nrhs, lwork, info;
00539     char trans='N';
00540     m = lda = A.rows();
00541     n = A.cols();
00542     ldb = n;
00543     nrhs = 1;
00544     lwork = m + std::max(n,nrhs);
00545 
00546     it_assert_debug(m < n, "The system is over-determined!");
00547     it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
00548 
00549     cvec work(lwork);
00550     x = b;
00551     x.set_size(n, true);
00552     cmat QR = A;
00553 
00554     zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00555 
00556     return (info==0);
00557   }
00558 
00559   bool ls_solve_ud(const cmat &A, const cmat &B, cmat &X)
00560   {
00561     int m, n, lda, ldb, nrhs, lwork, info;
00562     char trans='N';
00563     m = lda = A.rows();
00564     n = A.cols();
00565     ldb = n;
00566     nrhs = B.cols();
00567     lwork = m + std::max(n,nrhs);
00568 
00569     it_assert_debug(m < n, "The system is over-determined!");
00570     it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
00571 
00572     cvec work(lwork);
00573     X = B;
00574     X.set_size(n, std::max(m, nrhs), true);
00575     cmat QR = A;
00576 
00577     zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00578     X.set_size(n, nrhs, true);
00579 
00580     return (info==0);
00581   }
00582 
00583 #else
00584 
00585   bool ls_solve_ud(const mat &A, const vec &b, vec &x)
00586   {
00587     it_error("LAPACK library is needed to use ls_solve_ud() function");
00588     return false;
00589   }
00590 
00591   bool ls_solve_ud(const mat &A, const mat &B, mat &X)
00592   {
00593     it_error("LAPACK library is needed to use ls_solve_ud() function");
00594     return false;
00595   }
00596 
00597   bool ls_solve_ud(const cmat &A, const cvec &b, cvec &x)
00598   {
00599     it_error("LAPACK library is needed to use ls_solve_ud() function");
00600     return false;
00601   }
00602 
00603   bool ls_solve_ud(const cmat &A, const cmat &B, cmat &X)
00604   {
00605     it_error("LAPACK library is needed to use ls_solve_ud() function");
00606     return false;
00607   }
00608 
00609 #endif // HAVE_LAPACK
00610 
00611 
00612   vec ls_solve_ud(const mat &A, const vec &b)
00613   {
00614     vec x;
00615     bool info;
00616     info = ls_solve_ud(A, b, x);
00617     it_assert_debug(info, "ls_solve_ud: Failed solving the system");
00618     return x;
00619   }
00620 
00621   mat ls_solve_ud(const mat &A, const mat &B)
00622   {
00623     mat X;
00624     bool info;
00625     info = ls_solve_ud(A, B, X);
00626     it_assert_debug(info, "ls_solve_ud: Failed solving the system");
00627     return X;
00628   }
00629 
00630   cvec ls_solve_ud(const cmat &A, const cvec &b)
00631   {
00632     cvec x;
00633     bool info;
00634     info = ls_solve_ud(A, b, x);
00635     it_assert_debug(info, "ls_solve_ud: Failed solving the system");
00636     return x;
00637   }
00638 
00639   cmat ls_solve_ud(const cmat &A, const cmat &B)
00640   {
00641     cmat X;
00642     bool info;
00643     info = ls_solve_ud(A, B, X);
00644     it_assert_debug(info, "ls_solve_ud: Failed solving the system");
00645     return X;
00646   }
00647 
00648 
00649   // ---------------------- backslash -----------------------------------------
00650 
00651   bool backslash(const mat &A, const vec &b, vec &x)
00652   {
00653     int m=A.rows(), n=A.cols();
00654     bool info;
00655 
00656     if (m == n)
00657       info = ls_solve(A,b,x);
00658     else if (m > n)
00659       info = ls_solve_od(A,b,x);
00660     else
00661       info = ls_solve_ud(A,b,x);
00662 
00663     return info;
00664   }
00665 
00666 
00667   vec backslash(const mat &A, const vec &b)
00668   {
00669     vec x;
00670     bool info;
00671     info = backslash(A, b, x);
00672     it_assert_debug(info, "backslash(): solution was not found");
00673     return x;
00674   }
00675 
00676 
00677   bool backslash(const mat &A, const mat &B, mat &X)
00678   {
00679     int m=A.rows(), n=A.cols();
00680     bool info;
00681 
00682     if (m == n)
00683       info = ls_solve(A, B, X);
00684     else if (m > n)
00685       info = ls_solve_od(A, B, X);
00686     else
00687       info = ls_solve_ud(A, B, X);
00688 
00689     return info;
00690   }
00691 
00692 
00693   mat backslash(const mat &A, const mat &B)
00694   {
00695     mat X;
00696     bool info;
00697     info = backslash(A, B, X);
00698     it_assert_debug(info, "backslash(): solution was not found");
00699     return X;
00700   }
00701 
00702 
00703   bool backslash(const cmat &A, const cvec &b, cvec &x)
00704   {
00705     int m=A.rows(), n=A.cols();
00706     bool info;
00707 
00708     if (m == n)
00709       info = ls_solve(A,b,x);
00710     else if (m > n)
00711       info = ls_solve_od(A,b,x);
00712     else
00713       info = ls_solve_ud(A,b,x);
00714 
00715     return info;
00716   }
00717 
00718 
00719   cvec backslash(const cmat &A, const cvec &b)
00720   {
00721     cvec x;
00722     bool info;
00723     info = backslash(A, b, x);
00724     it_assert_debug(info, "backslash(): solution was not found");
00725     return x;
00726   }
00727 
00728 
00729   bool backslash(const cmat &A, const cmat &B, cmat &X)
00730   {
00731     int m=A.rows(), n=A.cols();
00732     bool info;
00733 
00734     if (m == n)
00735       info = ls_solve(A, B, X);
00736     else if (m > n)
00737       info = ls_solve_od(A, B, X);
00738     else
00739       info = ls_solve_ud(A, B, X);
00740 
00741     return info;
00742   }
00743 
00744   cmat backslash(const cmat &A, const cmat &B)
00745   {
00746     cmat X;
00747     bool info;
00748     info = backslash(A, B, X);
00749     it_assert_debug(info, "backslash(): solution was not found");
00750     return X;
00751   }
00752 
00753 
00754   // --------------------------------------------------------------------------
00755 
00756   vec forward_substitution(const mat &L, const vec &b)
00757   {
00758     int n = L.rows();
00759     vec x(n);
00760 
00761     forward_substitution(L, b, x);
00762 
00763     return x;
00764   }
00765 
00766   void forward_substitution(const mat &L, const vec &b, vec &x)
00767   {
00768     it_assert( L.rows() == L.cols() && L.cols() == b.size() && b.size() == x.size(),
00769                "forward_substitution: dimension mismatch" );
00770     int n = L.rows(), i, j;
00771     double temp;
00772 
00773     x(0)=b(0)/L(0,0);
00774     for (i=1;i<n;i++) {
00775       // Should be: x(i)=((b(i)-L(i,i,0,i-1)*x(0,i-1))/L(i,i))(0); but this is to slow.
00776       //i_pos=i*L._row_offset();
00777       temp=0;
00778       for (j=0; j<i; j++) {
00779         temp += L._elem(i,j) * x(j);
00780         //temp+=L._data()[i_pos+j]*x(j);
00781       }
00782       x(i) = (b(i)-temp)/L._elem(i,i);
00783       //x(i)=(b(i)-temp)/L._data()[i_pos+i];
00784     }
00785   }
00786 
00787   vec forward_substitution(const mat &L, int p, const vec &b)
00788   {
00789     int n = L.rows();
00790     vec x(n);
00791 
00792     forward_substitution(L, p, b, x);
00793 
00794     return x;
00795   }
00796 
00797   void forward_substitution(const mat &L, int p, const vec &b, vec &x)
00798   {
00799     it_assert( L.rows() == L.cols() && L.cols() == b.size() && b.size() == x.size() && p <= L.rows()/2,
00800                "forward_substitution: dimension mismatch");
00801     int n = L.rows(), i, j;
00802 
00803     x=b;
00804 
00805     for (j=0;j<n;j++) {
00806       x(j)/=L(j,j);
00807       for (i=j+1;i<std::min(j+p+1,n);i++) {
00808         x(i)-=L(i,j)*x(j);
00809       }
00810     }
00811   }
00812 
00813   vec backward_substitution(const mat &U, const vec &b)
00814   {
00815     vec x(U.rows());
00816     backward_substitution(U, b, x);
00817 
00818     return x;
00819   }
00820 
00821   void backward_substitution(const mat &U, const vec &b, vec &x)
00822   {
00823     it_assert( U.rows() == U.cols() && U.cols() == b.size() && b.size() == x.size(),
00824                "backward_substitution: dimension mismatch" );
00825     int n = U.rows(), i, j;
00826     double temp;
00827 
00828     x(n-1)=b(n-1)/U(n-1,n-1);
00829     for (i=n-2; i>=0; i--) {
00830       // Should be: x(i)=((b(i)-U(i,i,i+1,n-1)*x(i+1,n-1))/U(i,i))(0); but this is too slow.
00831       temp=0;
00832       //i_pos=i*U._row_offset();
00833       for (j=i+1; j<n; j++) {
00834         temp += U._elem(i,j) * x(j);
00835         //temp+=U._data()[i_pos+j]*x(j);
00836       }
00837       x(i) = (b(i)-temp)/U._elem(i,i);
00838       //x(i)=(b(i)-temp)/U._data()[i_pos+i];
00839     }
00840   }
00841 
00842   vec backward_substitution(const mat &U, int q, const vec &b)
00843   {
00844     vec x(U.rows());
00845     backward_substitution(U, q, b, x);
00846 
00847     return x;
00848   }
00849 
00850   void backward_substitution(const mat &U, int q, const vec &b, vec &x)
00851   {
00852     it_assert( U.rows() == U.cols() && U.cols() == b.size() && b.size() == x.size() && q <= U.rows()/2,
00853                "backward_substitution: dimension mismatch" );
00854     int n = U.rows(), i, j;
00855 
00856     x=b;
00857 
00858     for (j=n-1; j>=0; j--) {
00859       x(j) /= U(j,j);
00860       for (i=std::max(0,j-q); i<j; i++) {
00861         x(i)-=U(i,j)*x(j);
00862       }
00863     }
00864   }
00865 
00866 } // namespace itpp
SourceForge Logo

Generated on Sat Apr 19 10:41:54 2008 for IT++ by Doxygen 1.5.5