gemm_mixed.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2009 NICTA
00002 // 
00003 // Authors:
00004 // - Conrad Sanderson (conradsand at ieee dot org)
00005 // 
00006 // This file is part of the Armadillo C++ library.
00007 // It is provided without any warranty of fitness
00008 // for any purpose. You can redistribute this file
00009 // and/or modify it under the terms of the GNU
00010 // Lesser General Public License (LGPL) as published
00011 // by the Free Software Foundation, either version 3
00012 // of the License or (at your option) any later version.
00013 // (see http://www.opensource.org/licenses for more info)
00014 
00015 
00016 //! \addtogroup gemm_mixed
00017 //! @{
00018 
00019 
00020 
00021 //! \brief
00022 //! Matrix multplication where the matrices have different element types.
00023 //! Uses caching for speedup.
00024 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
00025 
00026 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00027 class gemm_mixed_cache
00028   {
00029   public:
00030   
00031   template<typename out_eT, typename in_eT1, typename in_eT2>
00032   inline
00033   static
00034   void
00035   apply
00036     (
00037           Mat<out_eT>& C,
00038     const Mat<in_eT1>& A,
00039     const Mat<in_eT2>& B,
00040     const out_eT alpha = out_eT(1),
00041     const out_eT beta  = out_eT(0)
00042     )
00043     {
00044     arma_extra_debug_sigprint();
00045     
00046     const u32 A_n_rows = A.n_rows;
00047     const u32 A_n_cols = A.n_cols;
00048     
00049     const u32 B_n_rows = B.n_rows;
00050     const u32 B_n_cols = B.n_cols;
00051     
00052     if( (do_trans_A == false) && (do_trans_B == false) )
00053       {
00054       podarray<in_eT1> tmp(A_n_cols);
00055       in_eT1* A_rowdata = tmp.memptr();
00056       
00057       for(u32 row_A=0; row_A < A_n_rows; ++row_A)
00058         {
00059         
00060         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00061           {
00062           A_rowdata[col_A] = A.at(row_A,col_A);
00063           }
00064         
00065         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00066           {
00067           const in_eT2* B_coldata = B.colptr(col_B);
00068           
00069           out_eT acc = out_eT(0);
00070           for(u32 i=0; i < B_n_rows; ++i)
00071             {
00072             acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00073             }
00074         
00075           if( (use_alpha == false) && (use_beta == false) )
00076             {
00077             C.at(row_A,col_B) = acc;
00078             }
00079           else
00080           if( (use_alpha == true) && (use_beta == false) )
00081             {
00082             C.at(row_A,col_B) = alpha * acc;
00083             }
00084           else
00085           if( (use_alpha == false) && (use_beta == true) )
00086             {
00087             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00088             }
00089           else
00090           if( (use_alpha == true) && (use_beta == true) )
00091             {
00092             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00093             }
00094           
00095           }
00096         }
00097       }
00098     else
00099     if( (do_trans_A == true) && (do_trans_B == false) )
00100       {
00101       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00102         {
00103         // col_A is interpreted as row_A when storing the results in matrix C
00104         
00105         const in_eT1* A_coldata = A.colptr(col_A);
00106         
00107         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00108           {
00109           const in_eT2* B_coldata = B.colptr(col_B);
00110           
00111           out_eT acc = out_eT(0);
00112           for(u32 i=0; i < B_n_rows; ++i)
00113             {
00114             acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00115             }
00116         
00117           if( (use_alpha == false) && (use_beta == false) )
00118             {
00119             C.at(col_A,col_B) = acc;
00120             }
00121           else
00122           if( (use_alpha == true) && (use_beta == false) )
00123             {
00124             C.at(col_A,col_B) = alpha * acc;
00125             }
00126           else
00127           if( (use_alpha == false) && (use_beta == true) )
00128             {
00129             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00130             }
00131           else
00132           if( (use_alpha == true) && (use_beta == true) )
00133             {
00134             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00135             }
00136           
00137           }
00138         }
00139       }
00140     else
00141     if( (do_trans_A == false) && (do_trans_B == true) )
00142       {
00143       Mat<in_eT2> B_tmp = trans(B);
00144       gemm_mixed_cache<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00145       }
00146     else
00147     if( (do_trans_A == true) && (do_trans_B == true) )
00148       {
00149       // mat B_tmp = trans(B);
00150       // dgemm_arma<true, false,  use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
00151       
00152       
00153       // By using the trans(A)*trans(B) = trans(B*A) equivalency,
00154       // transpose operations are not needed
00155       
00156       podarray<in_eT2> tmp(B.n_cols);
00157       in_eT2* B_rowdata = tmp.memptr();
00158       
00159       for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00160         {
00161         
00162         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00163           {
00164           B_rowdata[col_B] = B.at(row_B,col_B);
00165           }
00166         
00167         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00168           {
00169           const in_eT1* A_coldata = A.colptr(col_A);
00170           
00171           out_eT acc = out_eT(0);
00172           for(u32 i=0; i < A_n_rows; ++i)
00173             {
00174             acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00175             }
00176         
00177           if( (use_alpha == false) && (use_beta == false) )
00178             {
00179             C.at(col_A,row_B) = acc;
00180             }
00181           else
00182           if( (use_alpha == true) && (use_beta == false) )
00183             {
00184             C.at(col_A,row_B) = alpha * acc;
00185             }
00186           else
00187           if( (use_alpha == false) && (use_beta == true) )
00188             {
00189             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00190             }
00191           else
00192           if( (use_alpha == true) && (use_beta == true) )
00193             {
00194             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00195             }
00196           
00197           }
00198         }
00199       
00200       }
00201     }
00202     
00203   };
00204 
00205 
00206 
00207 //! Matrix multplication where the matrices have different element types.
00208 //! Simple version (no caching).
00209 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
00210 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00211 class gemm_mixed_simple
00212   {
00213   public:
00214   
00215   template<typename out_eT, typename in_eT1, typename in_eT2>
00216   inline
00217   static
00218   void
00219   apply
00220     (
00221           Mat<out_eT>& C,
00222     const Mat<in_eT1>& A,
00223     const Mat<in_eT2>& B,
00224     const out_eT alpha = out_eT(1),
00225     const out_eT beta  = out_eT(0)
00226     )
00227     {
00228     arma_extra_debug_sigprint();
00229     
00230     const u32 A_n_rows = A.n_rows;
00231     const u32 A_n_cols = A.n_cols;
00232     
00233     const u32 B_n_rows = B.n_rows;
00234     const u32 B_n_cols = B.n_cols;
00235     
00236     if( (do_trans_A == false) && (do_trans_B == false) )
00237       {
00238       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00239         {
00240         for(u32 col_B = 0; col_B < B_n_cols; ++col_B)
00241           {
00242           const in_eT2* B_coldata = B.colptr(col_B);
00243           
00244           out_eT acc = out_eT(0);
00245           for(u32 i = 0; i < B_n_rows; ++i)
00246             {
00247             const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i));
00248             const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00249             acc += val1 * val2;
00250             //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00251             }
00252           
00253           if( (use_alpha == false) && (use_beta == false) )
00254             {
00255             C.at(row_A,col_B) = acc;
00256             }
00257           else
00258           if( (use_alpha == true) && (use_beta == false) )
00259             {
00260             C.at(row_A,col_B) = alpha * acc;
00261             }
00262           else
00263           if( (use_alpha == false) && (use_beta == true) )
00264             {
00265             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00266             }
00267           else
00268           if( (use_alpha == true) && (use_beta == true) )
00269             {
00270             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00271             }
00272           }
00273         }
00274       }
00275     else
00276     if( (do_trans_A == true) && (do_trans_B == false) )
00277       {
00278       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00279         {
00280         // col_A is interpreted as row_A when storing the results in matrix C
00281         
00282         const in_eT1* A_coldata = A.colptr(col_A);
00283         
00284         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00285           {
00286           const in_eT2* B_coldata = B.colptr(col_B);
00287           
00288           out_eT acc = out_eT(0);
00289           for(u32 i=0; i < B_n_rows; ++i)
00290             {
00291             acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00292             }
00293         
00294           if( (use_alpha == false) && (use_beta == false) )
00295             {
00296             C.at(col_A,col_B) = acc;
00297             }
00298           else
00299           if( (use_alpha == true) && (use_beta == false) )
00300             {
00301             C.at(col_A,col_B) = alpha * acc;
00302             }
00303           else
00304           if( (use_alpha == false) && (use_beta == true) )
00305             {
00306             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00307             }
00308           else
00309           if( (use_alpha == true) && (use_beta == true) )
00310             {
00311             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00312             }
00313           
00314           }
00315         }
00316       }
00317     else
00318     if( (do_trans_A == false) && (do_trans_B == true) )
00319       {
00320       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00321         {
00322         for(u32 row_B = 0; row_B < B_n_rows; ++row_B)
00323           {
00324           out_eT acc = out_eT(0);
00325           for(u32 i = 0; i < B_n_cols; ++i)
00326             {
00327             acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i));
00328             }
00329           
00330           if( (use_alpha == false) && (use_beta == false) )
00331             {
00332             C.at(row_A,row_B) = acc;
00333             }
00334           else
00335           if( (use_alpha == true) && (use_beta == false) )
00336             {
00337             C.at(row_A,row_B) = alpha * acc;
00338             }
00339           else
00340           if( (use_alpha == false) && (use_beta == true) )
00341             {
00342             C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
00343             }
00344           else
00345           if( (use_alpha == true) && (use_beta == true) )
00346             {
00347             C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
00348             }
00349           }
00350         }
00351       }
00352     else
00353     if( (do_trans_A == true) && (do_trans_B == true) )
00354       {
00355       for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00356         {
00357         
00358         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00359           {
00360           const in_eT1* A_coldata = A.colptr(col_A);
00361           
00362           out_eT acc = out_eT(0);
00363           for(u32 i=0; i < A_n_rows; ++i)
00364             {
00365             acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00366             }
00367         
00368           if( (use_alpha == false) && (use_beta == false) )
00369             {
00370             C.at(col_A,row_B) = acc;
00371             }
00372           else
00373           if( (use_alpha == true) && (use_beta == false) )
00374             {
00375             C.at(col_A,row_B) = alpha * acc;
00376             }
00377           else
00378           if( (use_alpha == false) && (use_beta == true) )
00379             {
00380             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00381             }
00382           else
00383           if( (use_alpha == true) && (use_beta == true) )
00384             {
00385             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00386             }
00387           
00388           }
00389         }
00390       
00391       }
00392     }
00393     
00394   };
00395 
00396 
00397 
00398 
00399 
00400 //! \brief
00401 //! Matrix multplication where the matrices have different element types.
00402 
00403 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
00404 class gemm_mixed
00405   {
00406   public:
00407   
00408   //! immediate multiplication of matrices A and B, storing the result in C
00409   template<typename out_eT, typename in_eT1, typename in_eT2>
00410   inline
00411   static
00412   void
00413   apply
00414     (
00415           Mat<out_eT>& C,
00416     const Mat<in_eT1>& A,
00417     const Mat<in_eT2>& B,
00418     const out_eT alpha = out_eT(1),
00419     const out_eT beta  = out_eT(0)
00420     )
00421     {
00422     arma_extra_debug_sigprint();
00423     
00424     if( (A.n_elem <= 64u) && (B.n_elem <= 64u) )
00425       {
00426       gemm_mixed_simple<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00427       }
00428     else
00429       {
00430       gemm_mixed_cache<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
00431       }
00432     }
00433   
00434   };
00435 
00436 
00437 
00438 //! @}