gemm_emul_cache< do_trans_A, do_trans_B, use_alpha, use_beta > Class Template Reference
[Gemm]

Partial emulation of ATLAS/BLAS gemm(), using caching for speedup. Matrix 'C' is assumed to have already been set to the correct size, i.e. the size of the product after any requested transposes of 'A' and 'B' have been taken into account. More...

#include <gemm.hpp>

List of all members.

Static Public Member Functions

template<typename eT >
static void apply (Mat< eT > &C, const Mat< eT > &A, const Mat< eT > &B, const eT alpha=eT(1), const eT beta=eT(0))


Detailed Description

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
class gemm_emul_cache< do_trans_A, do_trans_B, use_alpha, use_beta >

Partial emulation of ATLAS/BLAS gemm(), using caching for speedup. Matrix 'C' is assumed to have already been set to the correct size, i.e. the size of the product after any requested transposes of 'A' and 'B' have been taken into account.

Definition at line 26 of file gemm.hpp.


Member Function Documentation

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
template<typename eT >
static void gemm_emul_cache< do_trans_A, do_trans_B, use_alpha, use_beta >::apply ( Mat< eT > &  C,
const Mat< eT > &  A,
const Mat< eT > &  B,
const eT  alpha = eT(1),
const eT  beta = eT(0) 
) [inline, static]

Definition at line 35 of file gemm.hpp.

References Mat< eT >::at(), Mat< eT >::colptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and trans().

00042     {
          // Computes the matrix product of A and B (each optionally transposed, as selected by
          // do_trans_A / do_trans_B) and stores it element by element into C.
          // alpha is only read when use_alpha is true, and beta (together with the existing
          // contents of C) is only read when use_beta is true.
          // NOTE(review): this body performs no dimension checks and never resizes C;
          // presumably the caller guarantees conforming sizes -- confirm.
00043     arma_extra_debug_sigprint();
00044 
00045     const u32 A_n_rows = A.n_rows;
00046     const u32 A_n_cols = A.n_cols;
00047     
00048     const u32 B_n_rows = B.n_rows;
00049     const u32 B_n_cols = B.n_cols;
00050     
          // Case 1: A * B
00051     if( (do_trans_A == false) && (do_trans_B == false) )
00052       {
            // scratch buffer that holds one row of A (A_n_cols elements)
00053       arma_aligned podarray<eT> tmp(A_n_cols);
00054       eT* A_rowdata = tmp.memptr();
00055       
00056       for(u32 row_A=0; row_A < A_n_rows; ++row_A)
00057         {
00058         
              // copy row row_A of A into the scratch buffer once; it is then reused
              // for every column of B in the loop below
00059         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00060           {
00061           A_rowdata[col_A] = A.at(row_A,col_A);
00062           }
00063         
00064         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00065           {
00066           const eT* B_coldata = B.colptr(col_B);
00067           
                // dot product of the cached row of A with column col_B of B
                // NOTE(review): the bound is B_n_rows while the buffer holds A_n_cols elements;
                // this relies on A_n_cols == B_n_rows, which is not verified here
00068           eT acc = eT(0);
00069           for(u32 i=0; i < B_n_rows; ++i)
00070             {
00071             acc += A_rowdata[i] * B_coldata[i];
00072             }
00073         
                // The conditions below depend only on the template parameters use_alpha and
                // use_beta, so exactly one of the four stores applies for a given instantiation.
                // The use_beta variants read the current value of C(row_A,col_B) before
                // overwriting it.
00074           if( (use_alpha == false) && (use_beta == false) )
00075             {
00076             C.at(row_A,col_B) = acc;
00077             }
00078           else
00079           if( (use_alpha == true) && (use_beta == false) )
00080             {
00081             C.at(row_A,col_B) = alpha * acc;
00082             }
00083           else
00084           if( (use_alpha == false) && (use_beta == true) )
00085             {
00086             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00087             }
00088           else
00089           if( (use_alpha == true) && (use_beta == true) )
00090             {
00091             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00092             }
00093           
00094           }
00095         }
00096       }
00097     else
          // Case 2: trans(A) * B
          // Both operands are read directly through colptr(); no scratch buffer is used.
00098     if( (do_trans_A == true) && (do_trans_B == false) )
00099       {
00100       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00101         {
00102         // col_A is interpreted as row_A when storing the results in matrix C
00103         
00104         const eT* A_coldata = A.colptr(col_A);
00105         
00106         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00107           {
00108           const eT* B_coldata = B.colptr(col_B);
00109           
                // dot product of column col_A of A with column col_B of B
                // NOTE(review): A_coldata is indexed up to B_n_rows; this relies on
                // A_n_rows == B_n_rows, which is not verified here
00110           eT acc = eT(0);
00111           for(u32 i=0; i < B_n_rows; ++i)
00112             {
00113             acc += A_coldata[i] * B_coldata[i];
00114             }
00115         
                // same four alpha/beta combinations as in case 1, with the result
                // stored at C(col_A,col_B)
00116           if( (use_alpha == false) && (use_beta == false) )
00117             {
00118             C.at(col_A,col_B) = acc;
00119             }
00120           else
00121           if( (use_alpha == true) && (use_beta == false) )
00122             {
00123             C.at(col_A,col_B) = alpha * acc;
00124             }
00125           else
00126           if( (use_alpha == false) && (use_beta == true) )
00127             {
00128             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00129             }
00130           else
00131           if( (use_alpha == true) && (use_beta == true) )
00132             {
00133             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00134             }
00135           
00136           }
00137         }
00138       }
00139     else
          // Case 3: A * trans(B)
          // trans(B) is materialised in a temporary matrix and the work is delegated to the
          // no-transpose instantiation, keeping the same use_alpha / use_beta settings.
00140     if( (do_trans_A == false) && (do_trans_B == true) )
00141       {
00142       Mat<eT> B_tmp = trans(B);
00143       gemm_emul_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00144       }
00145     else
          // Case 4: trans(A) * trans(B)
00146     if( (do_trans_A == true) && (do_trans_B == true) )
00147       {
            // (disabled alternative, kept as commented-out code: explicit transpose of B)
00148       // mat B_tmp = trans(B);
00149       // dgemm_arma<true, false,  use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00150       
00151       
00152       // By using the trans(A)*trans(B) = trans(B*A) equivalency,
00153       // transpose operations are not needed
00154       
            // scratch buffer that holds one row of B (B.n_cols elements)
00155       arma_aligned podarray<eT> tmp(B.n_cols);
00156       eT* B_rowdata = tmp.memptr();
00157       
00158       for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00159         {
00160         
              // copy row row_B of B into the scratch buffer once; it is then reused
              // for every column of A in the loop below
00161         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00162           {
00163           B_rowdata[col_B] = B.at(row_B,col_B);
00164           }
00165         
00166         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00167           {
00168           const eT* A_coldata = A.colptr(col_A);
00169           
                // dot product of row row_B of B with column col_A of A, i.e. element
                // (row_B,col_A) of B*A
                // NOTE(review): the bound is A_n_rows while the buffer holds B.n_cols elements;
                // this relies on B_n_cols == A_n_rows, which is not verified here
00170           eT acc = eT(0);
00171           for(u32 i=0; i < A_n_rows; ++i)
00172             {
00173             acc += B_rowdata[i] * A_coldata[i];
00174             }
00175         
                // the indices are swapped on store, C(col_A,row_B), which applies the outer
                // transpose of trans(B*A); same four alpha/beta combinations as in case 1
00176           if( (use_alpha == false) && (use_beta == false) )
00177             {
00178             C.at(col_A,row_B) = acc;
00179             }
00180           else
00181           if( (use_alpha == true) && (use_beta == false) )
00182             {
00183             C.at(col_A,row_B) = alpha * acc;
00184             }
00185           else
00186           if( (use_alpha == false) && (use_beta == true) )
00187             {
00188             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00189             }
00190           else
00191           if( (use_alpha == true) && (use_beta == true) )
00192             {
00193             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00194             }
00195           
00196           }
00197         }
00198       
00199       }
00200     }