#include <gemm_mixed.hpp>
Static Public Member Functions
template<typename out_eT, typename in_eT1, typename in_eT2>
static void apply(Mat<out_eT>& C, const Mat<in_eT1>& A, const Mat<in_eT2>& B, const out_eT alpha = out_eT(1), const out_eT beta = out_eT(0))
Definition at line 211 of file gemm_mixed.hpp.
static void gemm_mixed_simple<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(Mat<out_eT>& C, const Mat<in_eT1>& A, const Mat<in_eT2>& B, const out_eT alpha = out_eT(1), const out_eT beta = out_eT(0))   [inline, static]
Definition at line 220 of file gemm_mixed.hpp.
References Mat< eT >::at(), Mat< eT >::colptr(), Mat< eT >::n_cols, and Mat< eT >::n_rows.
00227 { 00228 arma_extra_debug_sigprint(); 00229 00230 const u32 A_n_rows = A.n_rows; 00231 const u32 A_n_cols = A.n_cols; 00232 00233 const u32 B_n_rows = B.n_rows; 00234 const u32 B_n_cols = B.n_cols; 00235 00236 if( (do_trans_A == false) && (do_trans_B == false) ) 00237 { 00238 for(u32 row_A = 0; row_A < A_n_rows; ++row_A) 00239 { 00240 for(u32 col_B = 0; col_B < B_n_cols; ++col_B) 00241 { 00242 const in_eT2* B_coldata = B.colptr(col_B); 00243 00244 out_eT acc = out_eT(0); 00245 for(u32 i = 0; i < B_n_rows; ++i) 00246 { 00247 const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)); 00248 const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00249 acc += val1 * val2; 00250 //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00251 } 00252 00253 if( (use_alpha == false) && (use_beta == false) ) 00254 { 00255 C.at(row_A,col_B) = acc; 00256 } 00257 else 00258 if( (use_alpha == true) && (use_beta == false) ) 00259 { 00260 C.at(row_A,col_B) = alpha * acc; 00261 } 00262 else 00263 if( (use_alpha == false) && (use_beta == true) ) 00264 { 00265 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); 00266 } 00267 else 00268 if( (use_alpha == true) && (use_beta == true) ) 00269 { 00270 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); 00271 } 00272 } 00273 } 00274 } 00275 else 00276 if( (do_trans_A == true) && (do_trans_B == false) ) 00277 { 00278 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00279 { 00280 // col_A is interpreted as row_A when storing the results in matrix C 00281 00282 const in_eT1* A_coldata = A.colptr(col_A); 00283 00284 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00285 { 00286 const in_eT2* B_coldata = B.colptr(col_B); 00287 00288 out_eT acc = out_eT(0); 00289 for(u32 i=0; i < B_n_rows; ++i) 00290 { 00291 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00292 } 00293 00294 if( (use_alpha == false) && (use_beta == 
false) ) 00295 { 00296 C.at(col_A,col_B) = acc; 00297 } 00298 else 00299 if( (use_alpha == true) && (use_beta == false) ) 00300 { 00301 C.at(col_A,col_B) = alpha * acc; 00302 } 00303 else 00304 if( (use_alpha == false) && (use_beta == true) ) 00305 { 00306 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); 00307 } 00308 else 00309 if( (use_alpha == true) && (use_beta == true) ) 00310 { 00311 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); 00312 } 00313 00314 } 00315 } 00316 } 00317 else 00318 if( (do_trans_A == false) && (do_trans_B == true) ) 00319 { 00320 for(u32 row_A = 0; row_A < A_n_rows; ++row_A) 00321 { 00322 for(u32 row_B = 0; row_B < B_n_rows; ++row_B) 00323 { 00324 out_eT acc = out_eT(0); 00325 for(u32 i = 0; i < B_n_cols; ++i) 00326 { 00327 acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)); 00328 } 00329 00330 if( (use_alpha == false) && (use_beta == false) ) 00331 { 00332 C.at(row_A,row_B) = acc; 00333 } 00334 else 00335 if( (use_alpha == true) && (use_beta == false) ) 00336 { 00337 C.at(row_A,row_B) = alpha * acc; 00338 } 00339 else 00340 if( (use_alpha == false) && (use_beta == true) ) 00341 { 00342 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B); 00343 } 00344 else 00345 if( (use_alpha == true) && (use_beta == true) ) 00346 { 00347 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B); 00348 } 00349 } 00350 } 00351 } 00352 else 00353 if( (do_trans_A == true) && (do_trans_B == true) ) 00354 { 00355 for(u32 row_B=0; row_B < B_n_rows; ++row_B) 00356 { 00357 00358 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00359 { 00360 const in_eT1* A_coldata = A.colptr(col_A); 00361 00362 out_eT acc = out_eT(0); 00363 for(u32 i=0; i < A_n_rows; ++i) 00364 { 00365 acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]); 00366 } 00367 00368 if( (use_alpha == false) && (use_beta == false) ) 00369 { 00370 C.at(col_A,row_B) = acc; 00371 } 00372 else 00373 if( 
(use_alpha == true) && (use_beta == false) ) 00374 { 00375 C.at(col_A,row_B) = alpha * acc; 00376 } 00377 else 00378 if( (use_alpha == false) && (use_beta == true) ) 00379 { 00380 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); 00381 } 00382 else 00383 if( (use_alpha == true) && (use_beta == true) ) 00384 { 00385 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); 00386 } 00387 00388 } 00389 } 00390 00391 } 00392 }