#include <gemm.hpp>
Static Public Member Functions
template<typename eT>
static void apply(Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0))
Definition at line 26 of file gemm.hpp.
static void gemm_emul_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0)) [inline, static]
Definition at line 35 of file gemm.hpp.
References Mat< eT >::at(), Mat< eT >::colptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and trans().
00042 { 00043 arma_extra_debug_sigprint(); 00044 00045 const u32 A_n_rows = A.n_rows; 00046 const u32 A_n_cols = A.n_cols; 00047 00048 const u32 B_n_rows = B.n_rows; 00049 const u32 B_n_cols = B.n_cols; 00050 00051 if( (do_trans_A == false) && (do_trans_B == false) ) 00052 { 00053 arma_aligned podarray<eT> tmp(A_n_cols); 00054 eT* A_rowdata = tmp.memptr(); 00055 00056 for(u32 row_A=0; row_A < A_n_rows; ++row_A) 00057 { 00058 00059 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00060 { 00061 A_rowdata[col_A] = A.at(row_A,col_A); 00062 } 00063 00064 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00065 { 00066 const eT* B_coldata = B.colptr(col_B); 00067 00068 eT acc = eT(0); 00069 for(u32 i=0; i < B_n_rows; ++i) 00070 { 00071 acc += A_rowdata[i] * B_coldata[i]; 00072 } 00073 00074 if( (use_alpha == false) && (use_beta == false) ) 00075 { 00076 C.at(row_A,col_B) = acc; 00077 } 00078 else 00079 if( (use_alpha == true) && (use_beta == false) ) 00080 { 00081 C.at(row_A,col_B) = alpha * acc; 00082 } 00083 else 00084 if( (use_alpha == false) && (use_beta == true) ) 00085 { 00086 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); 00087 } 00088 else 00089 if( (use_alpha == true) && (use_beta == true) ) 00090 { 00091 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); 00092 } 00093 00094 } 00095 } 00096 } 00097 else 00098 if( (do_trans_A == true) && (do_trans_B == false) ) 00099 { 00100 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00101 { 00102 // col_A is interpreted as row_A when storing the results in matrix C 00103 00104 const eT* A_coldata = A.colptr(col_A); 00105 00106 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00107 { 00108 const eT* B_coldata = B.colptr(col_B); 00109 00110 eT acc = eT(0); 00111 for(u32 i=0; i < B_n_rows; ++i) 00112 { 00113 acc += A_coldata[i] * B_coldata[i]; 00114 } 00115 00116 if( (use_alpha == false) && (use_beta == false) ) 00117 { 00118 C.at(col_A,col_B) = acc; 00119 } 00120 else 00121 if( (use_alpha == true) && (use_beta == false) ) 00122 { 00123 
C.at(col_A,col_B) = alpha * acc; 00124 } 00125 else 00126 if( (use_alpha == false) && (use_beta == true) ) 00127 { 00128 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); 00129 } 00130 else 00131 if( (use_alpha == true) && (use_beta == true) ) 00132 { 00133 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); 00134 } 00135 00136 } 00137 } 00138 } 00139 else 00140 if( (do_trans_A == false) && (do_trans_B == true) ) 00141 { 00142 Mat<eT> B_tmp = trans(B); 00143 gemm_emul_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); 00144 } 00145 else 00146 if( (do_trans_A == true) && (do_trans_B == true) ) 00147 { 00148 // mat B_tmp = trans(B); 00149 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); 00150 00151 00152 // By using the trans(A)*trans(B) = trans(B*A) equivalency, 00153 // transpose operations are not needed 00154 00155 arma_aligned podarray<eT> tmp(B.n_cols); 00156 eT* B_rowdata = tmp.memptr(); 00157 00158 for(u32 row_B=0; row_B < B_n_rows; ++row_B) 00159 { 00160 00161 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00162 { 00163 B_rowdata[col_B] = B.at(row_B,col_B); 00164 } 00165 00166 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00167 { 00168 const eT* A_coldata = A.colptr(col_A); 00169 00170 eT acc = eT(0); 00171 for(u32 i=0; i < A_n_rows; ++i) 00172 { 00173 acc += B_rowdata[i] * A_coldata[i]; 00174 } 00175 00176 if( (use_alpha == false) && (use_beta == false) ) 00177 { 00178 C.at(col_A,row_B) = acc; 00179 } 00180 else 00181 if( (use_alpha == true) && (use_beta == false) ) 00182 { 00183 C.at(col_A,row_B) = alpha * acc; 00184 } 00185 else 00186 if( (use_alpha == false) && (use_beta == true) ) 00187 { 00188 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); 00189 } 00190 else 00191 if( (use_alpha == true) && (use_beta == true) ) 00192 { 00193 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); 00194 } 00195 00196 } 00197 } 00198 00199 } 00200 }