#include <gemm_mixed.hpp>
Static Public Member Functions

template<typename out_eT, typename in_eT1, typename in_eT2>
static void apply(Mat<out_eT>& C, const Mat<in_eT1>& A, const Mat<in_eT2>& B, const out_eT alpha = out_eT(1), const out_eT beta = out_eT(0))
Definition at line 27 of file gemm_mixed.hpp.
static void gemm_mixed_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(
    Mat<out_eT>&       C,
    const Mat<in_eT1>& A,
    const Mat<in_eT2>& B,
    const out_eT       alpha = out_eT(1),
    const out_eT       beta  = out_eT(0)
  )  [inline, static]
Definition at line 36 of file gemm_mixed.hpp.
References Mat< eT >::at(), Mat< eT >::colptr(), podarray< T1 >::memptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and trans().
00043 { 00044 arma_extra_debug_sigprint(); 00045 00046 const u32 A_n_rows = A.n_rows; 00047 const u32 A_n_cols = A.n_cols; 00048 00049 const u32 B_n_rows = B.n_rows; 00050 const u32 B_n_cols = B.n_cols; 00051 00052 if( (do_trans_A == false) && (do_trans_B == false) ) 00053 { 00054 podarray<in_eT1> tmp(A_n_cols); 00055 in_eT1* A_rowdata = tmp.memptr(); 00056 00057 for(u32 row_A=0; row_A < A_n_rows; ++row_A) 00058 { 00059 00060 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00061 { 00062 A_rowdata[col_A] = A.at(row_A,col_A); 00063 } 00064 00065 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00066 { 00067 const in_eT2* B_coldata = B.colptr(col_B); 00068 00069 out_eT acc = out_eT(0); 00070 for(u32 i=0; i < B_n_rows; ++i) 00071 { 00072 acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00073 } 00074 00075 if( (use_alpha == false) && (use_beta == false) ) 00076 { 00077 C.at(row_A,col_B) = acc; 00078 } 00079 else 00080 if( (use_alpha == true) && (use_beta == false) ) 00081 { 00082 C.at(row_A,col_B) = alpha * acc; 00083 } 00084 else 00085 if( (use_alpha == false) && (use_beta == true) ) 00086 { 00087 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); 00088 } 00089 else 00090 if( (use_alpha == true) && (use_beta == true) ) 00091 { 00092 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); 00093 } 00094 00095 } 00096 } 00097 } 00098 else 00099 if( (do_trans_A == true) && (do_trans_B == false) ) 00100 { 00101 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00102 { 00103 // col_A is interpreted as row_A when storing the results in matrix C 00104 00105 const in_eT1* A_coldata = A.colptr(col_A); 00106 00107 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00108 { 00109 const in_eT2* B_coldata = B.colptr(col_B); 00110 00111 out_eT acc = out_eT(0); 00112 for(u32 i=0; i < B_n_rows; ++i) 00113 { 00114 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00115 } 00116 00117 if( (use_alpha 
== false) && (use_beta == false) ) 00118 { 00119 C.at(col_A,col_B) = acc; 00120 } 00121 else 00122 if( (use_alpha == true) && (use_beta == false) ) 00123 { 00124 C.at(col_A,col_B) = alpha * acc; 00125 } 00126 else 00127 if( (use_alpha == false) && (use_beta == true) ) 00128 { 00129 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); 00130 } 00131 else 00132 if( (use_alpha == true) && (use_beta == true) ) 00133 { 00134 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); 00135 } 00136 00137 } 00138 } 00139 } 00140 else 00141 if( (do_trans_A == false) && (do_trans_B == true) ) 00142 { 00143 Mat<in_eT2> B_tmp = trans(B); 00144 gemm_mixed_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); 00145 } 00146 else 00147 if( (do_trans_A == true) && (do_trans_B == true) ) 00148 { 00149 // mat B_tmp = trans(B); 00150 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); 00151 00152 00153 // By using the trans(A)*trans(B) = trans(B*A) equivalency, 00154 // transpose operations are not needed 00155 00156 podarray<in_eT2> tmp(B.n_cols); 00157 in_eT2* B_rowdata = tmp.memptr(); 00158 00159 for(u32 row_B=0; row_B < B_n_rows; ++row_B) 00160 { 00161 00162 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00163 { 00164 B_rowdata[col_B] = B.at(row_B,col_B); 00165 } 00166 00167 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00168 { 00169 const in_eT1* A_coldata = A.colptr(col_A); 00170 00171 out_eT acc = out_eT(0); 00172 for(u32 i=0; i < A_n_rows; ++i) 00173 { 00174 acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]); 00175 } 00176 00177 if( (use_alpha == false) && (use_beta == false) ) 00178 { 00179 C.at(col_A,row_B) = acc; 00180 } 00181 else 00182 if( (use_alpha == true) && (use_beta == false) ) 00183 { 00184 C.at(col_A,row_B) = alpha * acc; 00185 } 00186 else 00187 if( (use_alpha == false) && (use_beta == true) ) 00188 { 00189 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); 00190 } 00191 
else 00192 if( (use_alpha == true) && (use_beta == true) ) 00193 { 00194 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); 00195 } 00196 00197 } 00198 } 00199 00200 } 00201 }