gemm.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2009 NICTA
00002 // 
00003 // Authors:
00004 // - Conrad Sanderson (conradsand at ieee dot org)
00005 // 
00006 // This file is part of the Armadillo C++ library.
00007 // It is provided without any warranty of fitness
00008 // for any purpose. You can redistribute this file
00009 // and/or modify it under the terms of the GNU
00010 // Lesser General Public License (LGPL) as published
00011 // by the Free Software Foundation, either version 3
00012 // of the License or (at your option) any later version.
00013 // (see http://www.opensource.org/licenses for more info)
00014 
00015 
00016 //! \addtogroup gemm
00017 //! @{
00018 
00019 
00020 
00021 //! \brief
00022 //! Partial emulation of ATLAS/BLAS gemm(), using caching for speedup.
00023 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
00024 
00025 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00026 class gemm_emul_cache
00027   {
00028   public:
00029   
00030   template<typename eT>
00031   inline
00032   static
00033   void
00034   apply
00035     (
00036           Mat<eT>& C,
00037     const Mat<eT>& A,
00038     const Mat<eT>& B,
00039     const eT alpha = eT(1),
00040     const eT beta  = eT(0)
00041     )
00042     {
00043     arma_extra_debug_sigprint();
00044 
00045     const u32 A_n_rows = A.n_rows;
00046     const u32 A_n_cols = A.n_cols;
00047     
00048     const u32 B_n_rows = B.n_rows;
00049     const u32 B_n_cols = B.n_cols;
00050     
00051     if( (do_trans_A == false) && (do_trans_B == false) )
00052       {
00053       arma_aligned podarray<eT> tmp(A_n_cols);
00054       eT* A_rowdata = tmp.memptr();
00055       
00056       for(u32 row_A=0; row_A < A_n_rows; ++row_A)
00057         {
00058         
00059         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00060           {
00061           A_rowdata[col_A] = A.at(row_A,col_A);
00062           }
00063         
00064         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00065           {
00066           const eT* B_coldata = B.colptr(col_B);
00067           
00068           eT acc = eT(0);
00069           for(u32 i=0; i < B_n_rows; ++i)
00070             {
00071             acc += A_rowdata[i] * B_coldata[i];
00072             }
00073         
00074           if( (use_alpha == false) && (use_beta == false) )
00075             {
00076             C.at(row_A,col_B) = acc;
00077             }
00078           else
00079           if( (use_alpha == true) && (use_beta == false) )
00080             {
00081             C.at(row_A,col_B) = alpha * acc;
00082             }
00083           else
00084           if( (use_alpha == false) && (use_beta == true) )
00085             {
00086             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00087             }
00088           else
00089           if( (use_alpha == true) && (use_beta == true) )
00090             {
00091             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00092             }
00093           
00094           }
00095         }
00096       }
00097     else
00098     if( (do_trans_A == true) && (do_trans_B == false) )
00099       {
00100       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00101         {
00102         // col_A is interpreted as row_A when storing the results in matrix C
00103         
00104         const eT* A_coldata = A.colptr(col_A);
00105         
00106         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00107           {
00108           const eT* B_coldata = B.colptr(col_B);
00109           
00110           eT acc = eT(0);
00111           for(u32 i=0; i < B_n_rows; ++i)
00112             {
00113             acc += A_coldata[i] * B_coldata[i];
00114             }
00115         
00116           if( (use_alpha == false) && (use_beta == false) )
00117             {
00118             C.at(col_A,col_B) = acc;
00119             }
00120           else
00121           if( (use_alpha == true) && (use_beta == false) )
00122             {
00123             C.at(col_A,col_B) = alpha * acc;
00124             }
00125           else
00126           if( (use_alpha == false) && (use_beta == true) )
00127             {
00128             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00129             }
00130           else
00131           if( (use_alpha == true) && (use_beta == true) )
00132             {
00133             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00134             }
00135           
00136           }
00137         }
00138       }
00139     else
00140     if( (do_trans_A == false) && (do_trans_B == true) )
00141       {
00142       Mat<eT> B_tmp = trans(B);
00143       gemm_emul_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00144       }
00145     else
00146     if( (do_trans_A == true) && (do_trans_B == true) )
00147       {
00148       // mat B_tmp = trans(B);
00149       // dgemm_arma<true, false,  use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00150       
00151       
00152       // By using the trans(A)*trans(B) = trans(B*A) equivalency,
00153       // transpose operations are not needed
00154       
00155       arma_aligned podarray<eT> tmp(B.n_cols);
00156       eT* B_rowdata = tmp.memptr();
00157       
00158       for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00159         {
00160         
00161         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00162           {
00163           B_rowdata[col_B] = B.at(row_B,col_B);
00164           }
00165         
00166         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00167           {
00168           const eT* A_coldata = A.colptr(col_A);
00169           
00170           eT acc = eT(0);
00171           for(u32 i=0; i < A_n_rows; ++i)
00172             {
00173             acc += B_rowdata[i] * A_coldata[i];
00174             }
00175         
00176           if( (use_alpha == false) && (use_beta == false) )
00177             {
00178             C.at(col_A,row_B) = acc;
00179             }
00180           else
00181           if( (use_alpha == true) && (use_beta == false) )
00182             {
00183             C.at(col_A,row_B) = alpha * acc;
00184             }
00185           else
00186           if( (use_alpha == false) && (use_beta == true) )
00187             {
00188             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00189             }
00190           else
00191           if( (use_alpha == true) && (use_beta == true) )
00192             {
00193             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00194             }
00195           
00196           }
00197         }
00198       
00199       }
00200     }
00201     
00202   };
00203 
00204 
00205 
00206 //! Partial emulation of ATLAS/BLAS gemm(), non-cached version.
00207 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
00208 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00209 class gemm_emul_simple
00210   {
00211   public:
00212   
00213   template<typename eT>
00214   inline
00215   static
00216   void
00217   apply
00218     (
00219           Mat<eT>& C,
00220     const Mat<eT>& A,
00221     const Mat<eT>& B,
00222     const eT alpha = eT(1),
00223     const eT beta  = eT(0)
00224     )
00225     {
00226     arma_extra_debug_sigprint();
00227     
00228     const u32 A_n_rows = A.n_rows;
00229     const u32 A_n_cols = A.n_cols;
00230     
00231     const u32 B_n_rows = B.n_rows;
00232     const u32 B_n_cols = B.n_cols;
00233     
00234     if( (do_trans_A == false) && (do_trans_B == false) )
00235       {
00236       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00237         {
00238         for(u32 col_B = 0; col_B < B_n_cols; ++col_B)
00239           {
00240           const eT* B_coldata = B.colptr(col_B);
00241           
00242           eT acc = eT(0);
00243           for(u32 i = 0; i < B_n_rows; ++i)
00244             {
00245             acc += A.at(row_A,i) * B_coldata[i];
00246             }
00247           
00248           if( (use_alpha == false) && (use_beta == false) )
00249             {
00250             C.at(row_A,col_B) = acc;
00251             }
00252           else
00253           if( (use_alpha == true) && (use_beta == false) )
00254             {
00255             C.at(row_A,col_B) = alpha * acc;
00256             }
00257           else
00258           if( (use_alpha == false) && (use_beta == true) )
00259             {
00260             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00261             }
00262           else
00263           if( (use_alpha == true) && (use_beta == true) )
00264             {
00265             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00266             }
00267           }
00268         }
00269       }
00270     else
00271     if( (do_trans_A == true) && (do_trans_B == false) )
00272       {
00273       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00274         {
00275         // col_A is interpreted as row_A when storing the results in matrix C
00276         
00277         const eT* A_coldata = A.colptr(col_A);
00278         
00279         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00280           {
00281           const eT* B_coldata = B.colptr(col_B);
00282           
00283           eT acc = eT(0);
00284           for(u32 i=0; i < B_n_rows; ++i)
00285             {
00286             acc += A_coldata[i] * B_coldata[i];
00287             }
00288         
00289           if( (use_alpha == false) && (use_beta == false) )
00290             {
00291             C.at(col_A,col_B) = acc;
00292             }
00293           else
00294           if( (use_alpha == true) && (use_beta == false) )
00295             {
00296             C.at(col_A,col_B) = alpha * acc;
00297             }
00298           else
00299           if( (use_alpha == false) && (use_beta == true) )
00300             {
00301             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00302             }
00303           else
00304           if( (use_alpha == true) && (use_beta == true) )
00305             {
00306             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00307             }
00308           
00309           }
00310         }
00311       }
00312     else
00313     if( (do_trans_A == false) && (do_trans_B == true) )
00314       {
00315       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00316         {
00317         for(u32 row_B = 0; row_B < B_n_rows; ++row_B)
00318           {
00319           eT acc = eT(0);
00320           for(u32 i = 0; i < B_n_cols; ++i)
00321             {
00322             acc += A.at(row_A,i) * B.at(row_B,i);
00323             }
00324           
00325           if( (use_alpha == false) && (use_beta == false) )
00326             {
00327             C.at(row_A,row_B) = acc;
00328             }
00329           else
00330           if( (use_alpha == true) && (use_beta == false) )
00331             {
00332             C.at(row_A,row_B) = alpha * acc;
00333             }
00334           else
00335           if( (use_alpha == false) && (use_beta == true) )
00336             {
00337             C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
00338             }
00339           else
00340           if( (use_alpha == true) && (use_beta == true) )
00341             {
00342             C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
00343             }
00344           }
00345         }
00346       }
00347     else
00348     if( (do_trans_A == true) && (do_trans_B == true) )
00349       {
00350       for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00351         {
00352         
00353         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00354           {
00355           const eT* A_coldata = A.colptr(col_A);
00356           
00357           eT acc = eT(0);
00358           for(u32 i=0; i < A_n_rows; ++i)
00359             {
00360             acc += B.at(row_B,i) * A_coldata[i];
00361             }
00362         
00363           if( (use_alpha == false) && (use_beta == false) )
00364             {
00365             C.at(col_A,row_B) = acc;
00366             }
00367           else
00368           if( (use_alpha == true) && (use_beta == false) )
00369             {
00370             C.at(col_A,row_B) = alpha * acc;
00371             }
00372           else
00373           if( (use_alpha == false) && (use_beta == true) )
00374             {
00375             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00376             }
00377           else
00378           if( (use_alpha == true) && (use_beta == true) )
00379             {
00380             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00381             }
00382           
00383           }
00384         }
00385       
00386       }
00387     }
00388     
00389   };
00390 
00391 
00392 
00393 //! \brief
00394 //! Wrapper for ATLAS/BLAS dgemm function, using template arguments to control the arguments passed to dgemm.
00395 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
00396 
00397 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00398 class gemm
00399   {
00400   public:
00401   
00402   template<typename eT>
00403   inline
00404   static
00405   void
00406   apply_blas_type( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
00407     {
00408     arma_extra_debug_sigprint();
00409     
00410     if( ((A.n_elem <= 64u) && (B.n_elem <= 64u)) )
00411       {
00412       gemm_emul_simple<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00413       }
00414     else
00415       {
00416       #if defined(ARMA_USE_ATLAS)
00417         {
00418         arma_extra_debug_print("atlas::cblas_gemm()");
00419         
00420         atlas::cblas_gemm<eT>
00421           (
00422           atlas::CblasColMajor,
00423           (do_trans_A) ? atlas::CblasTrans : atlas::CblasNoTrans,
00424           (do_trans_B) ? atlas::CblasTrans : atlas::CblasNoTrans,
00425           C.n_rows,
00426           C.n_cols,
00427           (do_trans_A) ? A.n_rows : A.n_cols,
00428           (use_alpha) ? alpha : eT(1),
00429           A.mem,
00430           (do_trans_A) ? A.n_rows : C.n_rows,
00431           B.mem,
00432           (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ),
00433           (use_beta) ? beta : eT(0),
00434           C.memptr(),
00435           C.n_rows
00436           );
00437         }
00438       #elif defined(ARMA_USE_BLAS)
00439         {
00440         arma_extra_debug_print("blas::gemm_()");
00441         
00442         const char trans_A = (do_trans_A) ? 'T' : 'N';
00443         const char trans_B = (do_trans_B) ? 'T' : 'N';
00444         
00445         const int m   = C.n_rows;
00446         const int n   = C.n_cols;
00447         const int k   = (do_trans_A) ? A.n_rows : A.n_cols;
00448         
00449         const eT local_alpha = (use_alpha) ? alpha : eT(1);
00450         
00451         const int lda = (do_trans_A) ? k : m;
00452         const int ldb = (do_trans_B) ? n : k;
00453         
00454         const eT local_beta  = (use_beta) ? beta : eT(0);
00455         
00456         arma_extra_debug_print( arma_boost::format("blas::gemm_(): trans_A = %c") % trans_A );
00457         arma_extra_debug_print( arma_boost::format("blas::gemm_(): trans_B = %c") % trans_B );
00458         
00459         blas::gemm_<eT>
00460           (
00461           &trans_A,
00462           &trans_B,
00463           &m,
00464           &n,
00465           &k,
00466           &local_alpha,
00467           A.mem,
00468           &lda,
00469           B.mem,
00470           &ldb,
00471           &local_beta,
00472           C.memptr(),
00473           &m
00474           );
00475         }
00476       #else
00477         {
00478         gemm_emul_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00479         }
00480       #endif
00481       }
00482     }
00483   
00484   
00485   
00486   //! immediate multiplication of matrices A and B, storing the result in C
00487   template<typename eT>
00488   inline
00489   static
00490   void
00491   apply( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
00492     {
00493     if( (A.n_elem <= 64u) && (B.n_elem <= 64u) )
00494       {
00495       gemm_emul_simple<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00496       }
00497     else
00498       {
00499       gemm_emul_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00500       }
00501     }
00502   
00503   
00504   
00505   inline
00506   static
00507   void
00508   apply
00509     (
00510           Mat<float>& C,
00511     const Mat<float>& A,
00512     const Mat<float>& B,
00513     const float alpha = float(1),
00514     const float beta  = float(0)
00515     )
00516     {
00517     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00518     }
00519   
00520   
00521   
00522   inline
00523   static
00524   void
00525   apply
00526     (
00527           Mat<double>& C,
00528     const Mat<double>& A,
00529     const Mat<double>& B,
00530     const double alpha = double(1),
00531     const double beta  = double(0)
00532     )
00533     {
00534     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00535     }
00536   
00537   
00538   
00539   inline
00540   static
00541   void
00542   apply
00543     (
00544           Mat< std::complex<float> >& C,
00545     const Mat< std::complex<float> >& A,
00546     const Mat< std::complex<float> >& B,
00547     const std::complex<float> alpha = std::complex<float>(1),
00548     const std::complex<float> beta  = std::complex<float>(0)
00549     )
00550     {
00551     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00552     }
00553   
00554   
00555   
00556   inline
00557   static
00558   void
00559   apply
00560     (
00561           Mat< std::complex<double> >& C,
00562     const Mat< std::complex<double> >& A,
00563     const Mat< std::complex<double> >& B,
00564     const std::complex<double> alpha = std::complex<double>(1),
00565     const std::complex<double> beta  = std::complex<double>(0)
00566     )
00567     {
00568     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
00569     }
00570   
00571   };
00572 
00573 
00574 
00575 //! @}