#include <gemm.hpp>
Static Public Member Functions

template<typename eT>
static void apply(Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0))
Definition at line 209 of file gemm.hpp.
static void gemm_emul_simple<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0)) [inline, static]
Definition at line 218 of file gemm.hpp.
References Mat< eT >::at(), Mat< eT >::colptr(), Mat< eT >::n_cols, and Mat< eT >::n_rows.
00225 { 00226 arma_extra_debug_sigprint(); 00227 00228 const u32 A_n_rows = A.n_rows; 00229 const u32 A_n_cols = A.n_cols; 00230 00231 const u32 B_n_rows = B.n_rows; 00232 const u32 B_n_cols = B.n_cols; 00233 00234 if( (do_trans_A == false) && (do_trans_B == false) ) 00235 { 00236 for(u32 row_A = 0; row_A < A_n_rows; ++row_A) 00237 { 00238 for(u32 col_B = 0; col_B < B_n_cols; ++col_B) 00239 { 00240 const eT* B_coldata = B.colptr(col_B); 00241 00242 eT acc = eT(0); 00243 for(u32 i = 0; i < B_n_rows; ++i) 00244 { 00245 acc += A.at(row_A,i) * B_coldata[i]; 00246 } 00247 00248 if( (use_alpha == false) && (use_beta == false) ) 00249 { 00250 C.at(row_A,col_B) = acc; 00251 } 00252 else 00253 if( (use_alpha == true) && (use_beta == false) ) 00254 { 00255 C.at(row_A,col_B) = alpha * acc; 00256 } 00257 else 00258 if( (use_alpha == false) && (use_beta == true) ) 00259 { 00260 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); 00261 } 00262 else 00263 if( (use_alpha == true) && (use_beta == true) ) 00264 { 00265 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); 00266 } 00267 } 00268 } 00269 } 00270 else 00271 if( (do_trans_A == true) && (do_trans_B == false) ) 00272 { 00273 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00274 { 00275 // col_A is interpreted as row_A when storing the results in matrix C 00276 00277 const eT* A_coldata = A.colptr(col_A); 00278 00279 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00280 { 00281 const eT* B_coldata = B.colptr(col_B); 00282 00283 eT acc = eT(0); 00284 for(u32 i=0; i < B_n_rows; ++i) 00285 { 00286 acc += A_coldata[i] * B_coldata[i]; 00287 } 00288 00289 if( (use_alpha == false) && (use_beta == false) ) 00290 { 00291 C.at(col_A,col_B) = acc; 00292 } 00293 else 00294 if( (use_alpha == true) && (use_beta == false) ) 00295 { 00296 C.at(col_A,col_B) = alpha * acc; 00297 } 00298 else 00299 if( (use_alpha == false) && (use_beta == true) ) 00300 { 00301 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); 00302 } 00303 else 00304 if( (use_alpha 
== true) && (use_beta == true) ) 00305 { 00306 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); 00307 } 00308 00309 } 00310 } 00311 } 00312 else 00313 if( (do_trans_A == false) && (do_trans_B == true) ) 00314 { 00315 for(u32 row_A = 0; row_A < A_n_rows; ++row_A) 00316 { 00317 for(u32 row_B = 0; row_B < B_n_rows; ++row_B) 00318 { 00319 eT acc = eT(0); 00320 for(u32 i = 0; i < B_n_cols; ++i) 00321 { 00322 acc += A.at(row_A,i) * B.at(row_B,i); 00323 } 00324 00325 if( (use_alpha == false) && (use_beta == false) ) 00326 { 00327 C.at(row_A,row_B) = acc; 00328 } 00329 else 00330 if( (use_alpha == true) && (use_beta == false) ) 00331 { 00332 C.at(row_A,row_B) = alpha * acc; 00333 } 00334 else 00335 if( (use_alpha == false) && (use_beta == true) ) 00336 { 00337 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B); 00338 } 00339 else 00340 if( (use_alpha == true) && (use_beta == true) ) 00341 { 00342 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B); 00343 } 00344 } 00345 } 00346 } 00347 else 00348 if( (do_trans_A == true) && (do_trans_B == true) ) 00349 { 00350 for(u32 row_B=0; row_B < B_n_rows; ++row_B) 00351 { 00352 00353 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00354 { 00355 const eT* A_coldata = A.colptr(col_A); 00356 00357 eT acc = eT(0); 00358 for(u32 i=0; i < A_n_rows; ++i) 00359 { 00360 acc += B.at(row_B,i) * A_coldata[i]; 00361 } 00362 00363 if( (use_alpha == false) && (use_beta == false) ) 00364 { 00365 C.at(col_A,row_B) = acc; 00366 } 00367 else 00368 if( (use_alpha == true) && (use_beta == false) ) 00369 { 00370 C.at(col_A,row_B) = alpha * acc; 00371 } 00372 else 00373 if( (use_alpha == false) && (use_beta == true) ) 00374 { 00375 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); 00376 } 00377 else 00378 if( (use_alpha == true) && (use_beta == true) ) 00379 { 00380 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); 00381 } 00382 00383 } 00384 } 00385 00386 } 00387 }