gemv.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2009 NICTA
00002 // 
00003 // Authors:
00004 // - Conrad Sanderson (conradsand at ieee dot org)
00005 // 
00006 // This file is part of the Armadillo C++ library.
00007 // It is provided without any warranty of fitness
00008 // for any purpose. You can redistribute this file
00009 // and/or modify it under the terms of the GNU
00010 // Lesser General Public License (LGPL) as published
00011 // by the Free Software Foundation, either version 3
00012 // of the License or (at your option) any later version.
00013 // (see http://www.opensource.org/licenses for more info)
00014 
00015 
00016 //! \addtogroup gemv
00017 //! @{
00018 
00019 
00020 
00021 //! \brief
00022 //! Partial emulation of ATLAS/BLAS gemv().
00023 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose)
00024 
00025 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00026 class gemv_arma
00027   {
00028   public:
00029   
00030   template<typename eT>
00031   inline
00032   static
00033   void
00034   apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00035     {
00036     arma_extra_debug_sigprint();
00037     
00038     const u32 A_n_rows = A.n_rows;
00039     const u32 A_n_cols = A.n_cols;
00040     
00041     if(do_trans_A == false)
00042       {
00043       for(u32 row=0; row < A_n_rows; ++row)
00044         {
00045         
00046         eT acc = eT(0);
00047         for(u32 col=0; col < A_n_cols; ++col)
00048           {
00049           acc += A.at(row,col) * x[col];
00050           }
00051           
00052         if( (use_alpha == false) && (use_beta == false) )
00053           {
00054           y[row] = acc;
00055           }
00056         else
00057         if( (use_alpha == true) && (use_beta == false) )
00058           {
00059           y[row] = alpha * acc;
00060           }
00061         else
00062         if( (use_alpha == false) && (use_beta == true) )
00063           {
00064           y[row] = acc + beta*y[row];
00065           }
00066         else
00067         if( (use_alpha == true) && (use_beta == true) )
00068           {
00069           y[row] = alpha*acc + beta*y[row];
00070           }
00071         }
00072       }
00073     else
00074     if(do_trans_A == true)
00075       {
00076       for(u32 col=0; col < A_n_cols; ++col)
00077         {
00078         // col is interpreted as row when storing the results in 'y'
00079         
00080         const eT* A_coldata = A.colptr(col);
00081         
00082         eT acc = eT(0);
00083         for(u32 row=0; row < A_n_rows; ++row)
00084           {
00085           acc += A_coldata[row] * x[row];
00086           }
00087       
00088         if( (use_alpha == false) && (use_beta == false) )
00089           {
00090           y[col] = acc;
00091           }
00092         else
00093         if( (use_alpha == true) && (use_beta == false) )
00094           {
00095           y[col] = alpha * acc;
00096           }
00097         else
00098         if( (use_alpha == false) && (use_beta == true) )
00099           {
00100           y[col] = acc + beta*y[col];
00101           }
00102         else
00103         if( (use_alpha == true) && (use_beta == true) )
00104           {
00105           y[col] = alpha*acc + beta*y[col];
00106           }
00107         
00108         }
00109       }
00110     }
00111     
00112   };
00113 
00114 
00115 
00116 //! \brief
00117 //! Wrapper for ATLAS/BLAS gemv function, using template arguments to control the arguments passed to gemv.
00118 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose)
00119 
00120 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00121 class gemv
00122   {
00123   public:
00124   
00125   template<typename eT>
00126   inline
00127   static
00128   void
00129   apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00130     {
00131     arma_extra_debug_sigprint();
00132     
00133     if( (A.n_elem <= 256u) || (is_supported_blas_type<eT>::value == false) )
00134      {
00135      gemv_arma<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00136      }
00137     else
00138       {
00139       #if defined(ARMA_USE_ATLAS)
00140         {
00141         arma_extra_debug_print("atlas::cblas_gemv()");
00142         
00143         atlas::cblas_gemv<eT>
00144           (
00145           atlas::CblasColMajor,
00146           (do_trans_A) ? atlas::CblasTrans : atlas::CblasNoTrans,
00147           A.n_rows,
00148           A.n_cols,
00149           (use_alpha) ? alpha : eT(1),
00150           A.mem,
00151           A.n_rows,
00152           x,
00153           1,
00154           (use_beta) ? beta : eT(0),
00155           y,
00156           1
00157           );
00158         }
00159       #elif defined(ARMA_USE_BLAS)
00160         {
00161         arma_extra_debug_print("blas::gemv_()");
00162         
00163         const char trans_A     = (do_trans_A) ? 'T' : 'N';
00164         const int  m           = A.n_rows;
00165         const int  n           = A.n_cols;
00166         const eT   local_alpha = (use_alpha) ? alpha : eT(1);
00167         //const int  lda         = A.n_rows;
00168         const int  inc         = 1;
00169         const eT   local_beta  = (use_beta) ? beta : eT(0);
00170         
00171         arma_extra_debug_print( arma_boost::format("blas::gemv_(): trans_A = %c") % trans_A );
00172 
00173         blas::gemv_<eT>
00174           (
00175           &trans_A,
00176           &m,
00177           &n,
00178           &local_alpha,
00179           A.mem,
00180           &m,  // lda
00181           x,
00182           &inc,
00183           &local_beta,
00184           y,
00185           &inc
00186           );
00187         }
00188       #else
00189         {
00190         gemv_arma<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00191         }
00192       #endif
00193       }
00194     
00195     }
00196   
00197   };
00198 
00199 
00200 //! @}