00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00026 class gemv_arma
00027 {
00028 public:
00029
00030 template<typename eT>
00031 inline
00032 static
00033 void
00034 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00035 {
00036 arma_extra_debug_sigprint();
00037
00038 const u32 A_n_rows = A.n_rows;
00039 const u32 A_n_cols = A.n_cols;
00040
00041 if(do_trans_A == false)
00042 {
00043 for(u32 row=0; row < A_n_rows; ++row)
00044 {
00045
00046 eT acc = eT(0);
00047 for(u32 col=0; col < A_n_cols; ++col)
00048 {
00049 acc += A.at(row,col) * x[col];
00050 }
00051
00052 if( (use_alpha == false) && (use_beta == false) )
00053 {
00054 y[row] = acc;
00055 }
00056 else
00057 if( (use_alpha == true) && (use_beta == false) )
00058 {
00059 y[row] = alpha * acc;
00060 }
00061 else
00062 if( (use_alpha == false) && (use_beta == true) )
00063 {
00064 y[row] = acc + beta*y[row];
00065 }
00066 else
00067 if( (use_alpha == true) && (use_beta == true) )
00068 {
00069 y[row] = alpha*acc + beta*y[row];
00070 }
00071 }
00072 }
00073 else
00074 if(do_trans_A == true)
00075 {
00076 for(u32 col=0; col < A_n_cols; ++col)
00077 {
00078
00079
00080 const eT* A_coldata = A.colptr(col);
00081
00082 eT acc = eT(0);
00083 for(u32 row=0; row < A_n_rows; ++row)
00084 {
00085 acc += A_coldata[row] * x[row];
00086 }
00087
00088 if( (use_alpha == false) && (use_beta == false) )
00089 {
00090 y[col] = acc;
00091 }
00092 else
00093 if( (use_alpha == true) && (use_beta == false) )
00094 {
00095 y[col] = alpha * acc;
00096 }
00097 else
00098 if( (use_alpha == false) && (use_beta == true) )
00099 {
00100 y[col] = acc + beta*y[col];
00101 }
00102 else
00103 if( (use_alpha == true) && (use_beta == true) )
00104 {
00105 y[col] = alpha*acc + beta*y[col];
00106 }
00107
00108 }
00109 }
00110 }
00111
00112 };
00113
00114
00115
00116
00117
00118
00119
00120 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
00121 class gemv
00122 {
00123 public:
00124
00125 template<typename eT>
00126 inline
00127 static
00128 void
00129 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
00130 {
00131 arma_extra_debug_sigprint();
00132
00133 if( (A.n_elem <= 256u) || (is_supported_blas_type<eT>::value == false) )
00134 {
00135 gemv_arma<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00136 }
00137 else
00138 {
00139 #if defined(ARMA_USE_ATLAS)
00140 {
00141 arma_extra_debug_print("atlas::cblas_gemv()");
00142
00143 atlas::cblas_gemv<eT>
00144 (
00145 atlas::CblasColMajor,
00146 (do_trans_A) ? atlas::CblasTrans : atlas::CblasNoTrans,
00147 A.n_rows,
00148 A.n_cols,
00149 (use_alpha) ? alpha : eT(1),
00150 A.mem,
00151 A.n_rows,
00152 x,
00153 1,
00154 (use_beta) ? beta : eT(0),
00155 y,
00156 1
00157 );
00158 }
00159 #elif defined(ARMA_USE_BLAS)
00160 {
00161 arma_extra_debug_print("blas::gemv_()");
00162
00163 const char trans_A = (do_trans_A) ? 'T' : 'N';
00164 const int m = A.n_rows;
00165 const int n = A.n_cols;
00166 const eT local_alpha = (use_alpha) ? alpha : eT(1);
00167
00168 const int inc = 1;
00169 const eT local_beta = (use_beta) ? beta : eT(0);
00170
00171 arma_extra_debug_print( arma_boost::format("blas::gemv_(): trans_A = %c") % trans_A );
00172
00173 blas::gemv_<eT>
00174 (
00175 &trans_A,
00176 &m,
00177 &n,
00178 &local_alpha,
00179 A.mem,
00180 &m,
00181 x,
00182 &inc,
00183 &local_beta,
00184 y,
00185 &inc
00186 );
00187 }
00188 #else
00189 {
00190 gemv_arma<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
00191 }
00192 #endif
00193 }
00194
00195 }
00196
00197 };
00198
00199
00200