gemm_mixed_simple< do_trans_A, do_trans_B, use_alpha, use_beta > Class Template Reference
[Gemm_mixed]

Matrix multiplication where the matrices have different element types. Simple version (no caching). Matrix 'C' is assumed to have already been set to the correct size (i.e. taking transposes into account). More...

#include <gemm_mixed.hpp>

List of all members.

Static Public Member Functions

template<typename out_eT , typename in_eT1 , typename in_eT2 >
static void apply (Mat< out_eT > &C, const Mat< in_eT1 > &A, const Mat< in_eT2 > &B, const out_eT alpha=out_eT(1), const out_eT beta=out_eT(0))


Detailed Description

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
class gemm_mixed_simple< do_trans_A, do_trans_B, use_alpha, use_beta >

Matrix multiplication where the matrices have different element types. Simple version (no caching). Matrix 'C' is assumed to have already been set to the correct size (i.e. taking transposes into account).

Definition at line 211 of file gemm_mixed.hpp.


Member Function Documentation

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
template<typename out_eT , typename in_eT1 , typename in_eT2 >
static void gemm_mixed_simple< do_trans_A, do_trans_B, use_alpha, use_beta >::apply ( Mat< out_eT > &  C,
const Mat< in_eT1 > &  A,
const Mat< in_eT2 > &  B,
const out_eT  alpha = out_eT(1),
const out_eT  beta = out_eT(0) 
) [inline, static]

Definition at line 220 of file gemm_mixed.hpp.

References Mat< eT >::at(), Mat< eT >::colptr(), Mat< eT >::n_cols, and Mat< eT >::n_rows.

00227     {
00228     arma_extra_debug_sigprint();
00229     // No size checks are made here: the caller must ensure op(A) and op(B) are conformant and that C already has op(A).n_rows rows and op(B).n_cols columns.
00230     const u32 A_n_rows = A.n_rows;
00231     const u32 A_n_cols = A.n_cols;
00232     
00233     const u32 B_n_rows = B.n_rows;
00234     const u32 B_n_cols = B.n_cols;
00235     
00236     if( (do_trans_A == false) && (do_trans_B == false) )  // C = A*B : C(row_A,col_B) = sum_i A(row_A,i) * B(i,col_B)
00237       {
00238       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00239         {
00240         for(u32 col_B = 0; col_B < B_n_cols; ++col_B)
00241           {
00242           const in_eT2* B_coldata = B.colptr(col_B);
00243           // Dot product of row row_A of A with column col_B of B; i runs over B_n_rows, which must equal A_n_cols.
00244           out_eT acc = out_eT(0);
00245           for(u32 i = 0; i < B_n_rows; ++i)
00246             {
00247             const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i));
00248             const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00249             acc += val1 * val2;
00250             // NOTE(review): upgrade_val is defined elsewhere; presumably it promotes in_eT1/in_eT2 values to a common type before they are stored as out_eT -- confirm.
00251             }
00252           // use_alpha/use_beta are template constants, so exactly one branch below applies per instantiation; the beta forms read the existing value of C.
00253           if( (use_alpha == false) && (use_beta == false) )
00254             {
00255             C.at(row_A,col_B) = acc;
00256             }
00257           else
00258           if( (use_alpha == true) && (use_beta == false) )
00259             {
00260             C.at(row_A,col_B) = alpha * acc;
00261             }
00262           else
00263           if( (use_alpha == false) && (use_beta == true) )
00264             {
00265             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00266             }
00267           else
00268           if( (use_alpha == true) && (use_beta == true) )
00269             {
00270             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00271             }
00272           }
00273         }
00274       }
00275     else
00276     if( (do_trans_A == true) && (do_trans_B == false) )  // C = A^T*B : C(col_A,col_B) = sum_i A(i,col_A) * B(i,col_B)
00277       {
00278       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00279         {
00280         // col_A is interpreted as row_A when storing the results in matrix C
00281         
00282         const in_eT1* A_coldata = A.colptr(col_A);
00283         
00284         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00285           {
00286           const in_eT2* B_coldata = B.colptr(col_B);
00287           // Column col_A of A dotted with column col_B of B; i runs over B_n_rows, which must equal A_n_rows.
00288           out_eT acc = out_eT(0);
00289           for(u32 i=0; i < B_n_rows; ++i)
00290             {
00291             acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00292             }
00293         
00294           if( (use_alpha == false) && (use_beta == false) )
00295             {
00296             C.at(col_A,col_B) = acc;
00297             }
00298           else
00299           if( (use_alpha == true) && (use_beta == false) )
00300             {
00301             C.at(col_A,col_B) = alpha * acc;
00302             }
00303           else
00304           if( (use_alpha == false) && (use_beta == true) )
00305             {
00306             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00307             }
00308           else
00309           if( (use_alpha == true) && (use_beta == true) )
00310             {
00311             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00312             }
00313           
00314           }
00315         }
00316       }
00317     else
00318     if( (do_trans_A == false) && (do_trans_B == true) )  // C = A*B^T : C(row_A,row_B) = sum_i A(row_A,i) * B(row_B,i)
00319       {
00320       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00321         {
00322         for(u32 row_B = 0; row_B < B_n_rows; ++row_B)
00323           {
00324           out_eT acc = out_eT(0);
00325           for(u32 i = 0; i < B_n_cols; ++i)  // B_n_cols must equal A_n_cols
00326             {
00327             acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i));
00328             }
00329           
00330           if( (use_alpha == false) && (use_beta == false) )
00331             {
00332             C.at(row_A,row_B) = acc;
00333             }
00334           else
00335           if( (use_alpha == true) && (use_beta == false) )
00336             {
00337             C.at(row_A,row_B) = alpha * acc;
00338             }
00339           else
00340           if( (use_alpha == false) && (use_beta == true) )
00341             {
00342             C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
00343             }
00344           else
00345           if( (use_alpha == true) && (use_beta == true) )
00346             {
00347             C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
00348             }
00349           }
00350         }
00351       }
00352     else
00353     if( (do_trans_A == true) && (do_trans_B == true) )  // C = A^T*B^T : C(col_A,row_B) = sum_i B(row_B,i) * A(i,col_A)
00354       {
00355       for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00356         {
00357         // The outer loop is over rows of B, so C is filled one column (index row_B) at a time.
00358         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00359           {
00360           const in_eT1* A_coldata = A.colptr(col_A);
00361           
00362           out_eT acc = out_eT(0);
00363           for(u32 i=0; i < A_n_rows; ++i)  // A_n_rows must equal B_n_cols
00364             {
00365             acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00366             }
00367         
00368           if( (use_alpha == false) && (use_beta == false) )
00369             {
00370             C.at(col_A,row_B) = acc;
00371             }
00372           else
00373           if( (use_alpha == true) && (use_beta == false) )
00374             {
00375             C.at(col_A,row_B) = alpha * acc;
00376             }
00377           else
00378           if( (use_alpha == false) && (use_beta == true) )
00379             {
00380             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00381             }
00382           else
00383           if( (use_alpha == true) && (use_beta == true) )
00384             {
00385             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00386             }
00387           
00388           }
00389         }
00390       
00391       }
00392     }