gemm_emul_simple< do_trans_A, do_trans_B, use_alpha, use_beta > Class Template Reference
[Gemm]

Partial emulation of ATLAS/BLAS gemm(), non-cached version: computes C = alpha*op(A)*op(B) + beta*C, where op(X) is either X or trans(X). Matrix 'C' is assumed to have already been set to the correct size (i.e. taking transposes into account). More...

#include <gemm.hpp>

List of all members.

Static Public Member Functions

template<typename eT >
static void apply (Mat< eT > &C, const Mat< eT > &A, const Mat< eT > &B, const eT alpha=eT(1), const eT beta=eT(0))


Detailed Description

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
class gemm_emul_simple< do_trans_A, do_trans_B, use_alpha, use_beta >

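Partial emulation of ATLAS/BLAS gemm(), non-cached version. It computes C = alpha*op(A)*op(B) + beta*C. The template flags select the operation as follows:

- op(A) is trans(A) when do_trans_A is true, and likewise for B with do_trans_B.
- The alpha factor is applied only when use_alpha is true.
- The beta*C term, which reads the existing contents of C, is added only when use_beta is true.

Matrix 'C' is assumed to have already been set to the correct size (i.e. taking transposes into account). The function itself checks neither the size of C nor whether the inner dimensions of op(A) and op(B) agree.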

Definition at line 209 of file gemm.hpp.


Member Function Documentation

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
template<typename eT >
static void gemm_emul_simple< do_trans_A, do_trans_B, use_alpha, use_beta >::apply ( Mat< eT > &  C,
const Mat< eT > &  A,
const Mat< eT > &  B,
const eT  alpha = eT(1),
const eT  beta = eT(0) 
) [inline, static]

Definition at line 218 of file gemm.hpp.

References Mat< eT >::at(), Mat< eT >::colptr(), Mat< eT >::n_cols, and Mat< eT >::n_rows.

00225     {  // Computes C = alpha*op(A)*op(B) + beta*C, where op(X) is X or trans(X) per do_trans_A / do_trans_B; no size checks are done here
00226     arma_extra_debug_sigprint();
00227     
00228     const u32 A_n_rows = A.n_rows;
00229     const u32 A_n_cols = A.n_cols;
00230     
00231     const u32 B_n_rows = B.n_rows;
00232     const u32 B_n_cols = B.n_cols;
00233     
00234     if( (do_trans_A == false) && (do_trans_B == false) )  // case C = A * B; do_trans_* are template parameters, so each instantiation uses one case
00235       {
00236       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00237         {
00238         for(u32 col_B = 0; col_B < B_n_cols; ++col_B)
00239           {
00240           const eT* B_coldata = B.colptr(col_B);  // column col_B of B, indexed by row below
00241           
00242           eT acc = eT(0);
00243           for(u32 i = 0; i < B_n_rows; ++i)  // inner dimension; A_n_cols is assumed to equal B_n_rows
00244             {
00245             acc += A.at(row_A,i) * B_coldata[i];  // dot product of row row_A of A with column col_B of B
00246             }
00247           
00248           if( (use_alpha == false) && (use_beta == false) )  // C = op(A)*op(B)
00249             {
00250             C.at(row_A,col_B) = acc;
00251             }
00252           else
00253           if( (use_alpha == true) && (use_beta == false) )  // C = alpha*op(A)*op(B)
00254             {
00255             C.at(row_A,col_B) = alpha * acc;
00256             }
00257           else
00258           if( (use_alpha == false) && (use_beta == true) )  // C = op(A)*op(B) + beta*C (reads the existing value of C)
00259             {
00260             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00261             }
00262           else
00263           if( (use_alpha == true) && (use_beta == true) )  // C = alpha*op(A)*op(B) + beta*C
00264             {
00265             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00266             }
00267           }
00268         }
00269       }
00270     else
00271     if( (do_trans_A == true) && (do_trans_B == false) )  // case C = trans(A) * B
00272       {
00273       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00274         {
00275         // col_A is interpreted as row_A when storing the results in matrix C
00276         
00277         const eT* A_coldata = A.colptr(col_A);  // column col_A of A, i.e. row col_A of trans(A)
00278         
00279         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00280           {
00281           const eT* B_coldata = B.colptr(col_B);
00282           
00283           eT acc = eT(0);
00284           for(u32 i=0; i < B_n_rows; ++i)  // inner dimension; A_n_rows is assumed to equal B_n_rows
00285             {
00286             acc += A_coldata[i] * B_coldata[i];  // both operands are read down a column
00287             }
00288         
00289           if( (use_alpha == false) && (use_beta == false) )
00290             {
00291             C.at(col_A,col_B) = acc;
00292             }
00293           else
00294           if( (use_alpha == true) && (use_beta == false) )
00295             {
00296             C.at(col_A,col_B) = alpha * acc;
00297             }
00298           else
00299           if( (use_alpha == false) && (use_beta == true) )
00300             {
00301             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00302             }
00303           else
00304           if( (use_alpha == true) && (use_beta == true) )
00305             {
00306             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00307             }
00308           
00309           }
00310         }
00311       }
00312     else
00313     if( (do_trans_A == false) && (do_trans_B == true) )  // case C = A * trans(B)
00314       {
00315       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00316         {
00317         for(u32 row_B = 0; row_B < B_n_rows; ++row_B)  // row row_B of B is column row_B of trans(B)
00318           {
00319           eT acc = eT(0);
00320           for(u32 i = 0; i < B_n_cols; ++i)  // inner dimension; A_n_cols is assumed to equal B_n_cols
00321             {
00322             acc += A.at(row_A,i) * B.at(row_B,i);  // dot product of row row_A of A with row row_B of B
00323             }
00324           
00325           if( (use_alpha == false) && (use_beta == false) )
00326             {
00327             C.at(row_A,row_B) = acc;
00328             }
00329           else
00330           if( (use_alpha == true) && (use_beta == false) )
00331             {
00332             C.at(row_A,row_B) = alpha * acc;
00333             }
00334           else
00335           if( (use_alpha == false) && (use_beta == true) )
00336             {
00337             C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
00338             }
00339           else
00340           if( (use_alpha == true) && (use_beta == true) )
00341             {
00342             C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
00343             }
00344           }
00345         }
00346       }
00347     else
00348     if( (do_trans_A == true) && (do_trans_B == true) )  // case C = trans(A) * trans(B)
00349       {
00350       for(u32 row_B=0; row_B < B_n_rows; ++row_B)  // row row_B of B becomes column row_B of C
00351         {
00352         
00353         for(u32 col_A=0; col_A < A_n_cols; ++col_A)  // column col_A of A becomes row col_A of C
00354           {
00355           const eT* A_coldata = A.colptr(col_A);
00356           
00357           eT acc = eT(0);
00358           for(u32 i=0; i < A_n_rows; ++i)  // inner dimension; B_n_cols is assumed to equal A_n_rows
00359             {
00360             acc += B.at(row_B,i) * A_coldata[i];  // dot product of row row_B of B with column col_A of A
00361             }
00362         
00363           if( (use_alpha == false) && (use_beta == false) )
00364             {
00365             C.at(col_A,row_B) = acc;
00366             }
00367           else
00368           if( (use_alpha == true) && (use_beta == false) )
00369             {
00370             C.at(col_A,row_B) = alpha * acc;
00371             }
00372           else
00373           if( (use_alpha == false) && (use_beta == true) )
00374             {
00375             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00376             }
00377           else
00378           if( (use_alpha == true) && (use_beta == true) )
00379             {
00380             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00381             }
00382           
00383           }
00384         }
00385       
00386       }
00387     }