00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00027 class gemm_mixed_cache
00028 {
00029 public:
00030
00031 template<typename out_eT, typename in_eT1, typename in_eT2>
00032 inline
00033 static
00034 void
00035 apply
00036 (
00037 Mat<out_eT>& C,
00038 const Mat<in_eT1>& A,
00039 const Mat<in_eT2>& B,
00040 const out_eT alpha = out_eT(1),
00041 const out_eT beta = out_eT(0)
00042 )
00043 {
00044 arma_extra_debug_sigprint();
00045
00046 const u32 A_n_rows = A.n_rows;
00047 const u32 A_n_cols = A.n_cols;
00048
00049 const u32 B_n_rows = B.n_rows;
00050 const u32 B_n_cols = B.n_cols;
00051
00052 if( (do_trans_A == false) && (do_trans_B == false) )
00053 {
00054 podarray<in_eT1> tmp(A_n_cols);
00055 in_eT1* A_rowdata = tmp.memptr();
00056
00057 for(u32 row_A=0; row_A < A_n_rows; ++row_A)
00058 {
00059
00060 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00061 {
00062 A_rowdata[col_A] = A.at(row_A,col_A);
00063 }
00064
00065 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00066 {
00067 const in_eT2* B_coldata = B.colptr(col_B);
00068
00069 out_eT acc = out_eT(0);
00070 for(u32 i=0; i < B_n_rows; ++i)
00071 {
00072 acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00073 }
00074
00075 if( (use_alpha == false) && (use_beta == false) )
00076 {
00077 C.at(row_A,col_B) = acc;
00078 }
00079 else
00080 if( (use_alpha == true) && (use_beta == false) )
00081 {
00082 C.at(row_A,col_B) = alpha * acc;
00083 }
00084 else
00085 if( (use_alpha == false) && (use_beta == true) )
00086 {
00087 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00088 }
00089 else
00090 if( (use_alpha == true) && (use_beta == true) )
00091 {
00092 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00093 }
00094
00095 }
00096 }
00097 }
00098 else
00099 if( (do_trans_A == true) && (do_trans_B == false) )
00100 {
00101 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00102 {
00103
00104
00105 const in_eT1* A_coldata = A.colptr(col_A);
00106
00107 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00108 {
00109 const in_eT2* B_coldata = B.colptr(col_B);
00110
00111 out_eT acc = out_eT(0);
00112 for(u32 i=0; i < B_n_rows; ++i)
00113 {
00114 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00115 }
00116
00117 if( (use_alpha == false) && (use_beta == false) )
00118 {
00119 C.at(col_A,col_B) = acc;
00120 }
00121 else
00122 if( (use_alpha == true) && (use_beta == false) )
00123 {
00124 C.at(col_A,col_B) = alpha * acc;
00125 }
00126 else
00127 if( (use_alpha == false) && (use_beta == true) )
00128 {
00129 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00130 }
00131 else
00132 if( (use_alpha == true) && (use_beta == true) )
00133 {
00134 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00135 }
00136
00137 }
00138 }
00139 }
00140 else
00141 if( (do_trans_A == false) && (do_trans_B == true) )
00142 {
00143 Mat<in_eT2> B_tmp = trans(B);
00144 gemm_mixed_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00145 }
00146 else
00147 if( (do_trans_A == true) && (do_trans_B == true) )
00148 {
00149
00150
00151
00152
00153
00154
00155
00156 podarray<in_eT2> tmp(B.n_cols);
00157 in_eT2* B_rowdata = tmp.memptr();
00158
00159 for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00160 {
00161
00162 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00163 {
00164 B_rowdata[col_B] = B.at(row_B,col_B);
00165 }
00166
00167 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00168 {
00169 const in_eT1* A_coldata = A.colptr(col_A);
00170
00171 out_eT acc = out_eT(0);
00172 for(u32 i=0; i < A_n_rows; ++i)
00173 {
00174 acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00175 }
00176
00177 if( (use_alpha == false) && (use_beta == false) )
00178 {
00179 C.at(col_A,row_B) = acc;
00180 }
00181 else
00182 if( (use_alpha == true) && (use_beta == false) )
00183 {
00184 C.at(col_A,row_B) = alpha * acc;
00185 }
00186 else
00187 if( (use_alpha == false) && (use_beta == true) )
00188 {
00189 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00190 }
00191 else
00192 if( (use_alpha == true) && (use_beta == true) )
00193 {
00194 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00195 }
00196
00197 }
00198 }
00199
00200 }
00201 }
00202
00203 };
00204
00205
00206
00207
00208
00209
00210 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00211 class gemm_mixed_simple
00212 {
00213 public:
00214
00215 template<typename out_eT, typename in_eT1, typename in_eT2>
00216 inline
00217 static
00218 void
00219 apply
00220 (
00221 Mat<out_eT>& C,
00222 const Mat<in_eT1>& A,
00223 const Mat<in_eT2>& B,
00224 const out_eT alpha = out_eT(1),
00225 const out_eT beta = out_eT(0)
00226 )
00227 {
00228 arma_extra_debug_sigprint();
00229
00230 const u32 A_n_rows = A.n_rows;
00231 const u32 A_n_cols = A.n_cols;
00232
00233 const u32 B_n_rows = B.n_rows;
00234 const u32 B_n_cols = B.n_cols;
00235
00236 if( (do_trans_A == false) && (do_trans_B == false) )
00237 {
00238 for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00239 {
00240 for(u32 col_B = 0; col_B < B_n_cols; ++col_B)
00241 {
00242 const in_eT2* B_coldata = B.colptr(col_B);
00243
00244 out_eT acc = out_eT(0);
00245 for(u32 i = 0; i < B_n_rows; ++i)
00246 {
00247 const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i));
00248 const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00249 acc += val1 * val2;
00250
00251 }
00252
00253 if( (use_alpha == false) && (use_beta == false) )
00254 {
00255 C.at(row_A,col_B) = acc;
00256 }
00257 else
00258 if( (use_alpha == true) && (use_beta == false) )
00259 {
00260 C.at(row_A,col_B) = alpha * acc;
00261 }
00262 else
00263 if( (use_alpha == false) && (use_beta == true) )
00264 {
00265 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00266 }
00267 else
00268 if( (use_alpha == true) && (use_beta == true) )
00269 {
00270 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00271 }
00272 }
00273 }
00274 }
00275 else
00276 if( (do_trans_A == true) && (do_trans_B == false) )
00277 {
00278 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00279 {
00280
00281
00282 const in_eT1* A_coldata = A.colptr(col_A);
00283
00284 for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00285 {
00286 const in_eT2* B_coldata = B.colptr(col_B);
00287
00288 out_eT acc = out_eT(0);
00289 for(u32 i=0; i < B_n_rows; ++i)
00290 {
00291 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00292 }
00293
00294 if( (use_alpha == false) && (use_beta == false) )
00295 {
00296 C.at(col_A,col_B) = acc;
00297 }
00298 else
00299 if( (use_alpha == true) && (use_beta == false) )
00300 {
00301 C.at(col_A,col_B) = alpha * acc;
00302 }
00303 else
00304 if( (use_alpha == false) && (use_beta == true) )
00305 {
00306 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00307 }
00308 else
00309 if( (use_alpha == true) && (use_beta == true) )
00310 {
00311 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00312 }
00313
00314 }
00315 }
00316 }
00317 else
00318 if( (do_trans_A == false) && (do_trans_B == true) )
00319 {
00320 for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00321 {
00322 for(u32 row_B = 0; row_B < B_n_rows; ++row_B)
00323 {
00324 out_eT acc = out_eT(0);
00325 for(u32 i = 0; i < B_n_cols; ++i)
00326 {
00327 acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i));
00328 }
00329
00330 if( (use_alpha == false) && (use_beta == false) )
00331 {
00332 C.at(row_A,row_B) = acc;
00333 }
00334 else
00335 if( (use_alpha == true) && (use_beta == false) )
00336 {
00337 C.at(row_A,row_B) = alpha * acc;
00338 }
00339 else
00340 if( (use_alpha == false) && (use_beta == true) )
00341 {
00342 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
00343 }
00344 else
00345 if( (use_alpha == true) && (use_beta == true) )
00346 {
00347 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
00348 }
00349 }
00350 }
00351 }
00352 else
00353 if( (do_trans_A == true) && (do_trans_B == true) )
00354 {
00355 for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00356 {
00357
00358 for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00359 {
00360 const in_eT1* A_coldata = A.colptr(col_A);
00361
00362 out_eT acc = out_eT(0);
00363 for(u32 i=0; i < A_n_rows; ++i)
00364 {
00365 acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00366 }
00367
00368 if( (use_alpha == false) && (use_beta == false) )
00369 {
00370 C.at(col_A,row_B) = acc;
00371 }
00372 else
00373 if( (use_alpha == true) && (use_beta == false) )
00374 {
00375 C.at(col_A,row_B) = alpha * acc;
00376 }
00377 else
00378 if( (use_alpha == false) && (use_beta == true) )
00379 {
00380 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00381 }
00382 else
00383 if( (use_alpha == true) && (use_beta == true) )
00384 {
00385 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00386 }
00387
00388 }
00389 }
00390
00391 }
00392 }
00393
00394 };
00395
00396
00397
00398
00399
00400
00401
00402
00403 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00404 class gemm_mixed
00405 {
00406 public:
00407
00408
00409 template<typename out_eT, typename in_eT1, typename in_eT2>
00410 inline
00411 static
00412 void
00413 apply
00414 (
00415 Mat<out_eT>& C,
00416 const Mat<in_eT1>& A,
00417 const Mat<in_eT2>& B,
00418 const out_eT alpha = out_eT(1),
00419 const out_eT beta = out_eT(0)
00420 )
00421 {
00422 arma_extra_debug_sigprint();
00423
00424 if( (A.n_elem <= 64u) && (B.n_elem <= 64u) )
00425 {
00426 gemm_mixed_simple<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00427 }
00428 else
00429 {
00430 gemm_mixed_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00431 }
00432 }
00433
00434 };
00435
00436
00437
00438