00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00026 class gemm_emul_cache
00027 {
00028 public:
00029
00030 template<typename eT>
00031 inline
00032 static
00033 void
00034 apply
00035 (
00036 Mat<eT>& C,
00037 const Mat<eT>& A,
00038 const Mat<eT>& B,
00039 const eT alpha = eT(1),
00040 const eT beta = eT(0)
00041 )
00042 {
00043 arma_extra_debug_sigprint();
00044
00045 const u32 A_n_rows = A.n_rows;
00046 const u32 A_n_cols = A.n_cols;
00047
00048 const u32 B_n_rows = B.n_rows;
00049 const u32 B_n_cols = B.n_cols;
00050
00051 if( (do_trans_A == false) && (do_trans_B == false) )
00052 {
00053 arma_aligned podarray<eT> tmp(A_n_cols);
00054 eT* A_rowdata = tmp.memptr();
00055
00056 for(u32 row_A=0; row_A < A_n_rows; ++row_A)
00057 {
00058
00059 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00060 {
00061 A_rowdata[col_A] = A.at(row_A,col_A);
00062 }
00063
00064 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00065 {
00066 const eT* B_coldata = B.colptr(col_B);
00067
00068 eT acc = eT(0);
00069 for(u32 i=0; i < B_n_rows; ++i)
00070 {
00071 acc += A_rowdata[i] * B_coldata[i];
00072 }
00073
00074 if( (use_alpha == false) && (use_beta == false) )
00075 {
00076 C.at(row_A,col_B) = acc;
00077 }
00078 else
00079 if( (use_alpha == true) && (use_beta == false) )
00080 {
00081 C.at(row_A,col_B) = alpha * acc;
00082 }
00083 else
00084 if( (use_alpha == false) && (use_beta == true) )
00085 {
00086 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00087 }
00088 else
00089 if( (use_alpha == true) && (use_beta == true) )
00090 {
00091 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00092 }
00093
00094 }
00095 }
00096 }
00097 else
00098 if( (do_trans_A == true) && (do_trans_B == false) )
00099 {
00100 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00101 {
00102
00103
00104 const eT* A_coldata = A.colptr(col_A);
00105
00106 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00107 {
00108 const eT* B_coldata = B.colptr(col_B);
00109
00110 eT acc = eT(0);
00111 for(u32 i=0; i < B_n_rows; ++i)
00112 {
00113 acc += A_coldata[i] * B_coldata[i];
00114 }
00115
00116 if( (use_alpha == false) && (use_beta == false) )
00117 {
00118 C.at(col_A,col_B) = acc;
00119 }
00120 else
00121 if( (use_alpha == true) && (use_beta == false) )
00122 {
00123 C.at(col_A,col_B) = alpha * acc;
00124 }
00125 else
00126 if( (use_alpha == false) && (use_beta == true) )
00127 {
00128 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00129 }
00130 else
00131 if( (use_alpha == true) && (use_beta == true) )
00132 {
00133 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00134 }
00135
00136 }
00137 }
00138 }
00139 else
00140 if( (do_trans_A == false) && (do_trans_B == true) )
00141 {
00142 Mat<eT> B_tmp = trans(B);
00143 gemm_emul_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00144 }
00145 else
00146 if( (do_trans_A == true) && (do_trans_B == true) )
00147 {
00148
00149
00150
00151
00152
00153
00154
00155 arma_aligned podarray<eT> tmp(B.n_cols);
00156 eT* B_rowdata = tmp.memptr();
00157
00158 for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00159 {
00160
00161 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00162 {
00163 B_rowdata[col_B] = B.at(row_B,col_B);
00164 }
00165
00166 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00167 {
00168 const eT* A_coldata = A.colptr(col_A);
00169
00170 eT acc = eT(0);
00171 for(u32 i=0; i < A_n_rows; ++i)
00172 {
00173 acc += B_rowdata[i] * A_coldata[i];
00174 }
00175
00176 if( (use_alpha == false) && (use_beta == false) )
00177 {
00178 C.at(col_A,row_B) = acc;
00179 }
00180 else
00181 if( (use_alpha == true) && (use_beta == false) )
00182 {
00183 C.at(col_A,row_B) = alpha * acc;
00184 }
00185 else
00186 if( (use_alpha == false) && (use_beta == true) )
00187 {
00188 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00189 }
00190 else
00191 if( (use_alpha == true) && (use_beta == true) )
00192 {
00193 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00194 }
00195
00196 }
00197 }
00198
00199 }
00200 }
00201
00202 };
00203
00204
00205
00206
00207
00208 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00209 class gemm_emul_simple
00210 {
00211 public:
00212
00213 template<typename eT>
00214 inline
00215 static
00216 void
00217 apply
00218 (
00219 Mat<eT>& C,
00220 const Mat<eT>& A,
00221 const Mat<eT>& B,
00222 const eT alpha = eT(1),
00223 const eT beta = eT(0)
00224 )
00225 {
00226 arma_extra_debug_sigprint();
00227
00228 const u32 A_n_rows = A.n_rows;
00229 const u32 A_n_cols = A.n_cols;
00230
00231 const u32 B_n_rows = B.n_rows;
00232 const u32 B_n_cols = B.n_cols;
00233
00234 if( (do_trans_A == false) && (do_trans_B == false) )
00235 {
00236 for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00237 {
00238 for(u32 col_B = 0; col_B < B_n_cols; ++col_B)
00239 {
00240 const eT* B_coldata = B.colptr(col_B);
00241
00242 eT acc = eT(0);
00243 for(u32 i = 0; i < B_n_rows; ++i)
00244 {
00245 acc += A.at(row_A,i) * B_coldata[i];
00246 }
00247
00248 if( (use_alpha == false) && (use_beta == false) )
00249 {
00250 C.at(row_A,col_B) = acc;
00251 }
00252 else
00253 if( (use_alpha == true) && (use_beta == false) )
00254 {
00255 C.at(row_A,col_B) = alpha * acc;
00256 }
00257 else
00258 if( (use_alpha == false) && (use_beta == true) )
00259 {
00260 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00261 }
00262 else
00263 if( (use_alpha == true) && (use_beta == true) )
00264 {
00265 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00266 }
00267 }
00268 }
00269 }
00270 else
00271 if( (do_trans_A == true) && (do_trans_B == false) )
00272 {
00273 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00274 {
00275
00276
00277 const eT* A_coldata = A.colptr(col_A);
00278
00279 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00280 {
00281 const eT* B_coldata = B.colptr(col_B);
00282
00283 eT acc = eT(0);
00284 for(u32 i=0; i < B_n_rows; ++i)
00285 {
00286 acc += A_coldata[i] * B_coldata[i];
00287 }
00288
00289 if( (use_alpha == false) && (use_beta == false) )
00290 {
00291 C.at(col_A,col_B) = acc;
00292 }
00293 else
00294 if( (use_alpha == true) && (use_beta == false) )
00295 {
00296 C.at(col_A,col_B) = alpha * acc;
00297 }
00298 else
00299 if( (use_alpha == false) && (use_beta == true) )
00300 {
00301 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00302 }
00303 else
00304 if( (use_alpha == true) && (use_beta == true) )
00305 {
00306 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00307 }
00308
00309 }
00310 }
00311 }
00312 else
00313 if( (do_trans_A == false) && (do_trans_B == true) )
00314 {
00315 for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00316 {
00317 for(u32 row_B = 0; row_B < B_n_rows; ++row_B)
00318 {
00319 eT acc = eT(0);
00320 for(u32 i = 0; i < B_n_cols; ++i)
00321 {
00322 acc += A.at(row_A,i) * B.at(row_B,i);
00323 }
00324
00325 if( (use_alpha == false) && (use_beta == false) )
00326 {
00327 C.at(row_A,row_B) = acc;
00328 }
00329 else
00330 if( (use_alpha == true) && (use_beta == false) )
00331 {
00332 C.at(row_A,row_B) = alpha * acc;
00333 }
00334 else
00335 if( (use_alpha == false) && (use_beta == true) )
00336 {
00337 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
00338 }
00339 else
00340 if( (use_alpha == true) && (use_beta == true) )
00341 {
00342 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
00343 }
00344 }
00345 }
00346 }
00347 else
00348 if( (do_trans_A == true) && (do_trans_B == true) )
00349 {
00350 for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00351 {
00352
00353 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00354 {
00355 const eT* A_coldata = A.colptr(col_A);
00356
00357 eT acc = eT(0);
00358 for(u32 i=0; i < A_n_rows; ++i)
00359 {
00360 acc += B.at(row_B,i) * A_coldata[i];
00361 }
00362
00363 if( (use_alpha == false) && (use_beta == false) )
00364 {
00365 C.at(col_A,row_B) = acc;
00366 }
00367 else
00368 if( (use_alpha == true) && (use_beta == false) )
00369 {
00370 C.at(col_A,row_B) = alpha * acc;
00371 }
00372 else
00373 if( (use_alpha == false) && (use_beta == true) )
00374 {
00375 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00376 }
00377 else
00378 if( (use_alpha == true) && (use_beta == true) )
00379 {
00380 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00381 }
00382
00383 }
00384 }
00385
00386 }
00387 }
00388
00389 };
00390
00391
00392
00393
00394
00395
00396
00397 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00398 class gemm
00399 {
00400 public:
00401
00402 template<typename eT>
00403 inline
00404 static
00405 void
00406 apply_blas_type( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
00407 {
00408 arma_extra_debug_sigprint();
00409
00410 if( ((A.n_elem <= 64u) && (B.n_elem <= 64u)) )
00411 {
00412 gemm_emul_simple<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00413 }
00414 else
00415 {
00416 #if defined(ARMA_USE_ATLAS)
00417 {
00418 arma_extra_debug_print("atlas::cblas_gemm()");
00419
00420 atlas::cblas_gemm<eT>
00421 (
00422 atlas::CblasColMajor,
00423 (do_trans_A) ? atlas::CblasTrans : atlas::CblasNoTrans,
00424 (do_trans_B) ? atlas::CblasTrans : atlas::CblasNoTrans,
00425 C.n_rows,
00426 C.n_cols,
00427 (do_trans_A) ? A.n_rows : A.n_cols,
00428 (use_alpha) ? alpha : eT(1),
00429 A.mem,
00430 (do_trans_A) ? A.n_rows : C.n_rows,
00431 B.mem,
00432 (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ),
00433 (use_beta) ? beta : eT(0),
00434 C.memptr(),
00435 C.n_rows
00436 );
00437 }
00438 #elif defined(ARMA_USE_BLAS)
00439 {
00440 arma_extra_debug_print("blas::gemm_()");
00441
00442 const char trans_A = (do_trans_A) ? 'T' : 'N';
00443 const char trans_B = (do_trans_B) ? 'T' : 'N';
00444
00445 const int m = C.n_rows;
00446 const int n = C.n_cols;
00447 const int k = (do_trans_A) ? A.n_rows : A.n_cols;
00448
00449 const eT local_alpha = (use_alpha) ? alpha : eT(1);
00450
00451 const int lda = (do_trans_A) ? k : m;
00452 const int ldb = (do_trans_B) ? n : k;
00453
00454 const eT local_beta = (use_beta) ? beta : eT(0);
00455
00456 arma_extra_debug_print( arma_boost::format("blas::gemm_(): trans_A = %c") % trans_A );
00457 arma_extra_debug_print( arma_boost::format("blas::gemm_(): trans_B = %c") % trans_B );
00458
00459 blas::gemm_<eT>
00460 (
00461 &trans_A,
00462 &trans_B,
00463 &m,
00464 &n,
00465 &k,
00466 &local_alpha,
00467 A.mem,
00468 &lda,
00469 B.mem,
00470 &ldb,
00471 &local_beta,
00472 C.memptr(),
00473 &m
00474 );
00475 }
00476 #else
00477 {
00478 gemm_emul_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00479 }
00480 #endif
00481 }
00482 }
00483
00484
00485
00486
00487 template<typename eT>
00488 inline
00489 static
00490 void
00491 apply( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
00492 {
00493 if( (A.n_elem <= 64u) && (B.n_elem <= 64u) )
00494 {
00495 gemm_emul_simple<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00496 }
00497 else
00498 {
00499 gemm_emul_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00500 }
00501 }
00502
00503
00504
00505 inline
00506 static
00507 void
00508 apply
00509 (
00510 Mat<float>& C,
00511 const Mat<float>& A,
00512 const Mat<float>& B,
00513 const float alpha = float(1),
00514 const float beta = float(0)
00515 )
00516 {
00517 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00518 }
00519
00520
00521
00522 inline
00523 static
00524 void
00525 apply
00526 (
00527 Mat<double>& C,
00528 const Mat<double>& A,
00529 const Mat<double>& B,
00530 const double alpha = double(1),
00531 const double beta = double(0)
00532 )
00533 {
00534 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00535 }
00536
00537
00538
00539 inline
00540 static
00541 void
00542 apply
00543 (
00544 Mat< std::complex<float> >& C,
00545 const Mat< std::complex<float> >& A,
00546 const Mat< std::complex<float> >& B,
00547 const std::complex<float> alpha = std::complex<float>(1),
00548 const std::complex<float> beta = std::complex<float>(0)
00549 )
00550 {
00551 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00552 }
00553
00554
00555
00556 inline
00557 static
00558 void
00559 apply
00560 (
00561 Mat< std::complex<double> >& C,
00562 const Mat< std::complex<double> >& A,
00563 const Mat< std::complex<double> >& B,
00564 const std::complex<double> alpha = std::complex<double>(1),
00565 const std::complex<double> beta = std::complex<double>(0)
00566 )
00567 {
00568 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00569 }
00570
00571 };
00572
00573
00574
00575