gemm_mixed_cache< do_trans_A, do_trans_B, use_alpha, use_beta > Class Template Reference
[Gemm_mixed]

Matrix multiplication where the matrices have different element types. Uses caching for speedup. Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes). More...

#include <gemm_mixed.hpp>

List of all members.

Static Public Member Functions

template<typename out_eT , typename in_eT1 , typename in_eT2 >
static void apply (Mat< out_eT > &C, const Mat< in_eT1 > &A, const Mat< in_eT2 > &B, const out_eT alpha=out_eT(1), const out_eT beta=out_eT(0))


Detailed Description

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
class gemm_mixed_cache< do_trans_A, do_trans_B, use_alpha, use_beta >

Matrix multiplication where the matrices have different element types. Uses caching for speedup. Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes).

Definition at line 27 of file gemm_mixed.hpp.


Member Function Documentation

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
template<typename out_eT , typename in_eT1 , typename in_eT2 >
static void gemm_mixed_cache< do_trans_A, do_trans_B, use_alpha, use_beta >::apply ( Mat< out_eT > &  C,
const Mat< in_eT1 > &  A,
const Mat< in_eT2 > &  B,
const out_eT  alpha = out_eT(1),
const out_eT  beta = out_eT(0) 
) [inline, static]

Definition at line 36 of file gemm_mixed.hpp.

References Mat< eT >::at(), Mat< eT >::colptr(), podarray< T1 >::memptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and trans().

00043     {
          // Body of gemm_mixed_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply().
          // Stores into C the product of A and B, each used as-is or transposed
          // according to do_trans_A / do_trans_B. The product is multiplied by alpha
          // only when use_alpha is true, and beta times the previous contents of C is
          // added only when use_beta is true; otherwise that argument is not read.
          // C is accessed solely through C.at() and is never resized in this body,
          // and no dimension checks are visible here: the caller must pass
          // conforming sizes.
00044     arma_extra_debug_sigprint();
00045     
          // Read the dimensions once; they serve as loop bounds in every branch below.
00046     const u32 A_n_rows = A.n_rows;
00047     const u32 A_n_cols = A.n_cols;
00048     
00049     const u32 B_n_rows = B.n_rows;
00050     const u32 B_n_cols = B.n_cols;
00051     
          // Case A * B: process A one row at a time.
00052     if( (do_trans_A == false) && (do_trans_B == false) )
00053       {
            // Scratch buffer of A_n_cols elements that receives a copy of the current row of A.
00054       podarray<in_eT1> tmp(A_n_cols);
00055       in_eT1* A_rowdata = tmp.memptr();
00056       
00057       for(u32 row_A=0; row_A < A_n_rows; ++row_A)
00058         {
00059         
              // Copy row row_A of A into the scratch buffer so the inner loop can
              // index it with a plain pointer instead of calling A.at() each time.
00060         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00061           {
00062           A_rowdata[col_A] = A.at(row_A,col_A);
00063           }
00064         
00065         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00066           {
                // presumably a pointer to the first element of column col_B of B -- confirm in Mat
00067           const in_eT2* B_coldata = B.colptr(col_B);
00068           
                // Accumulate the sum of products of the copied row of A and column col_B of B.
                // A_rowdata holds A_n_cols elements but is indexed up to B_n_rows,
                // so A_n_cols == B_n_rows must already hold on entry.
                // upgrade_val<>::apply is defined elsewhere; presumably it converts each
                // operand to a common type before the multiply -- confirm against its definition.
00069           out_eT acc = out_eT(0);
00070           for(u32 i=0; i < B_n_rows; ++i)
00071             {
00072             acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00073             }
00074         
                // Exactly one of the four stores below applies, selected by the
                // template flags use_alpha / use_beta.
00075           if( (use_alpha == false) && (use_beta == false) )
00076             {
00077             C.at(row_A,col_B) = acc;
00078             }
00079           else
00080           if( (use_alpha == true) && (use_beta == false) )
00081             {
00082             C.at(row_A,col_B) = alpha * acc;
00083             }
00084           else
00085           if( (use_alpha == false) && (use_beta == true) )
00086             {
00087             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00088             }
00089           else
00090           if( (use_alpha == true) && (use_beta == true) )
00091             {
00092             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00093             }
00094           
00095           }
00096         }
00097       }
00098     else
          // Case trans(A) * B: each column of A is read in place through colptr(),
          // so no scratch copy is made in this branch.
00099     if( (do_trans_A == true) && (do_trans_B == false) )
00100       {
00101       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00102         {
00103         // col_A is interpreted as row_A when storing the results in matrix C
00104         
00105         const in_eT1* A_coldata = A.colptr(col_A);
00106         
00107         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00108           {
00109           const in_eT2* B_coldata = B.colptr(col_B);
00110           
                // Sum of products of column col_A of A and column col_B of B.
                // Both pointers are indexed up to B_n_rows; nothing here checks that
                // A has that many rows, so A_n_rows == B_n_rows is a precondition.
00111           out_eT acc = out_eT(0);
00112           for(u32 i=0; i < B_n_rows; ++i)
00113             {
00114             acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00115             }
00116         
                // Same four-way alpha/beta store as the first branch, written to C(col_A,col_B).
00117           if( (use_alpha == false) && (use_beta == false) )
00118             {
00119             C.at(col_A,col_B) = acc;
00120             }
00121           else
00122           if( (use_alpha == true) && (use_beta == false) )
00123             {
00124             C.at(col_A,col_B) = alpha * acc;
00125             }
00126           else
00127           if( (use_alpha == false) && (use_beta == true) )
00128             {
00129             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00130             }
00131           else
00132           if( (use_alpha == true) && (use_beta == true) )
00133             {
00134             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00135             }
00136           
00137           }
00138         }
00139       }
00140     else
          // Case A * trans(B): store the result of trans(B) in a temporary matrix and
          // delegate to the <false,false> instantiation, forwarding the same
          // use_alpha / use_beta flags and the alpha / beta values.
00141     if( (do_trans_A == false) && (do_trans_B == true) )
00142       {
00143       Mat<in_eT2> B_tmp = trans(B);
00144       gemm_mixed_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00145       }
00146     else
          // Case trans(A) * trans(B).
00147     if( (do_trans_A == true) && (do_trans_B == true) )
00148       {
00149       // mat B_tmp = trans(B);
00150       // dgemm_arma<true, false,  use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00151       
00152       
00153       // By using the trans(A)*trans(B) = trans(B*A) equivalency,
00154       // transpose operations are not needed
00155       
            // Scratch buffer of B.n_cols elements that receives a copy of the current row of B.
00156       podarray<in_eT2> tmp(B.n_cols);
00157       in_eT2* B_rowdata = tmp.memptr();
00158       
00159       for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00160         {
00161         
              // Copy row row_B of B into the scratch buffer.
00162         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00163           {
00164           B_rowdata[col_B] = B.at(row_B,col_B);
00165           }
00166         
00167         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00168           {
00169           const in_eT1* A_coldata = A.colptr(col_A);
00170           
                // Sum of products of the copied row of B and column col_A of A.
                // B_rowdata holds B.n_cols elements but is indexed up to A_n_rows,
                // so A_n_rows == B_n_cols must already hold on entry.
00171           out_eT acc = out_eT(0);
00172           for(u32 i=0; i < A_n_rows; ++i)
00173             {
00174             acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00175             }
00176         
                // The value computed for (row_B, col_A) is stored at C(col_A,row_B):
                // swapping the indices performs the outer transpose of trans(B*A).
00177           if( (use_alpha == false) && (use_beta == false) )
00178             {
00179             C.at(col_A,row_B) = acc;
00180             }
00181           else
00182           if( (use_alpha == true) && (use_beta == false) )
00183             {
00184             C.at(col_A,row_B) = alpha * acc;
00185             }
00186           else
00187           if( (use_alpha == false) && (use_beta == true) )
00188             {
00189             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00190             }
00191           else
00192           if( (use_alpha == true) && (use_beta == true) )
00193             {
00194             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00195             }
00196           
00197           }
00198         }
00199       
00200       }
00201     }