Matrix multiplication where the matrices have different element types. Simple version (no caching). Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes). More...
#include <gemm_mixed.hpp>
Static Public Member Functions
template<typename out_eT, typename in_eT1, typename in_eT2>
static arma_hot void apply(Mat< out_eT >& C, const Mat< in_eT1 >& A, const Mat< in_eT2 >& B, const out_eT alpha = out_eT(1), const out_eT beta = out_eT(0))
Matrix multiplication where the matrices have different element types. Simple version (no caching). Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes).
Definition at line 213 of file gemm_mixed.hpp.
template<typename out_eT, typename in_eT1, typename in_eT2>
static arma_hot void gemm_mixed_simple< do_trans_A, do_trans_B, use_alpha, use_beta >::apply(Mat< out_eT >& C, const Mat< in_eT1 >& A, const Mat< in_eT2 >& B, const out_eT alpha = out_eT(1), const out_eT beta = out_eT(0)) [inline, static]
Definition at line 223 of file gemm_mixed.hpp.
References Mat< eT >::at(), Mat< eT >::colptr(), Mat< eT >::n_cols, and Mat< eT >::n_rows.
00230 { 00231 arma_extra_debug_sigprint(); 00232 00233 const u32 A_n_rows = A.n_rows; 00234 const u32 A_n_cols = A.n_cols; 00235 00236 const u32 B_n_rows = B.n_rows; 00237 const u32 B_n_cols = B.n_cols; 00238 00239 if( (do_trans_A == false) && (do_trans_B == false) ) 00240 { 00241 for(u32 row_A = 0; row_A < A_n_rows; ++row_A) 00242 { 00243 for(u32 col_B = 0; col_B < B_n_cols; ++col_B) 00244 { 00245 const in_eT2* B_coldata = B.colptr(col_B); 00246 00247 out_eT acc = out_eT(0); 00248 for(u32 i = 0; i < B_n_rows; ++i) 00249 { 00250 const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)); 00251 const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00252 acc += val1 * val2; 00253 //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00254 } 00255 00256 if( (use_alpha == false) && (use_beta == false) ) 00257 { 00258 C.at(row_A,col_B) = acc; 00259 } 00260 else 00261 if( (use_alpha == true) && (use_beta == false) ) 00262 { 00263 C.at(row_A,col_B) = alpha * acc; 00264 } 00265 else 00266 if( (use_alpha == false) && (use_beta == true) ) 00267 { 00268 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); 00269 } 00270 else 00271 if( (use_alpha == true) && (use_beta == true) ) 00272 { 00273 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); 00274 } 00275 } 00276 } 00277 } 00278 else 00279 if( (do_trans_A == true) && (do_trans_B == false) ) 00280 { 00281 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00282 { 00283 // col_A is interpreted as row_A when storing the results in matrix C 00284 00285 const in_eT1* A_coldata = A.colptr(col_A); 00286 00287 for(u32 col_B=0; col_B < B_n_cols; ++col_B) 00288 { 00289 const in_eT2* B_coldata = B.colptr(col_B); 00290 00291 out_eT acc = out_eT(0); 00292 for(u32 i=0; i < B_n_rows; ++i) 00293 { 00294 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); 00295 } 00296 00297 if( (use_alpha == false) && (use_beta == 
false) ) 00298 { 00299 C.at(col_A,col_B) = acc; 00300 } 00301 else 00302 if( (use_alpha == true) && (use_beta == false) ) 00303 { 00304 C.at(col_A,col_B) = alpha * acc; 00305 } 00306 else 00307 if( (use_alpha == false) && (use_beta == true) ) 00308 { 00309 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); 00310 } 00311 else 00312 if( (use_alpha == true) && (use_beta == true) ) 00313 { 00314 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); 00315 } 00316 00317 } 00318 } 00319 } 00320 else 00321 if( (do_trans_A == false) && (do_trans_B == true) ) 00322 { 00323 for(u32 row_A = 0; row_A < A_n_rows; ++row_A) 00324 { 00325 for(u32 row_B = 0; row_B < B_n_rows; ++row_B) 00326 { 00327 out_eT acc = out_eT(0); 00328 for(u32 i = 0; i < B_n_cols; ++i) 00329 { 00330 acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)); 00331 } 00332 00333 if( (use_alpha == false) && (use_beta == false) ) 00334 { 00335 C.at(row_A,row_B) = acc; 00336 } 00337 else 00338 if( (use_alpha == true) && (use_beta == false) ) 00339 { 00340 C.at(row_A,row_B) = alpha * acc; 00341 } 00342 else 00343 if( (use_alpha == false) && (use_beta == true) ) 00344 { 00345 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B); 00346 } 00347 else 00348 if( (use_alpha == true) && (use_beta == true) ) 00349 { 00350 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B); 00351 } 00352 } 00353 } 00354 } 00355 else 00356 if( (do_trans_A == true) && (do_trans_B == true) ) 00357 { 00358 for(u32 row_B=0; row_B < B_n_rows; ++row_B) 00359 { 00360 00361 for(u32 col_A=0; col_A < A_n_cols; ++col_A) 00362 { 00363 const in_eT1* A_coldata = A.colptr(col_A); 00364 00365 out_eT acc = out_eT(0); 00366 for(u32 i=0; i < A_n_rows; ++i) 00367 { 00368 acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]); 00369 } 00370 00371 if( (use_alpha == false) && (use_beta == false) ) 00372 { 00373 C.at(col_A,row_B) = acc; 00374 } 00375 else 00376 if( 
(use_alpha == true) && (use_beta == false) ) 00377 { 00378 C.at(col_A,row_B) = alpha * acc; 00379 } 00380 else 00381 if( (use_alpha == false) && (use_beta == true) ) 00382 { 00383 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); 00384 } 00385 else 00386 if( (use_alpha == true) && (use_beta == true) ) 00387 { 00388 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); 00389 } 00390 00391 } 00392 } 00393 00394 } 00395 }