glue_times_meat.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2009 NICTA
00002 // 
00003 // Authors:
00004 // - Conrad Sanderson (conradsand at ieee dot org)
00005 // 
00006 // This file is part of the Armadillo C++ library.
00007 // It is provided without any warranty of fitness
00008 // for any purpose. You can redistribute this file
00009 // and/or modify it under the terms of the GNU
00010 // Lesser General Public License (LGPL) as published
00011 // by the Free Software Foundation, either version 3
00012 // of the License or (at your option) any later version.
00013 // (see http://www.opensource.org/licenses for more info)
00014 
00015 
00016 //! \addtogroup glue_times
00017 //! @{
00018 
00019 
00020 template<typename eT>
00021 arma_inline
00022 u32 glue_times::mul_storage_cost(const Mat<eT>& X, const Mat<eT>& Y)
00023   {
00024   return X.n_rows * Y.n_cols;
00025   }
00026 
00027 
00028 
00029 //! multiply matrices A and B, storing the result in 'out'
00030 //! assumes that A and B are not aliases of 'out'
00031 template<typename eT>
00032 inline
00033 void
00034 glue_times::apply_noalias(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B)
00035   {
00036   arma_extra_debug_sigprint();
00037   
00038   arma_debug_assert_mul_size(A, B, "matrix multiply");
00039   
00040   out.set_size(A.n_rows,B.n_cols);
00041   gemm<>::apply(out,A,B);
00042   }
00043 
00044 
00045 
00046 template<typename eT>
00047 inline
00048 void
00049 glue_times::apply(Mat<eT>& out, const Mat<eT>& A_in, const Mat<eT>& B_in)
00050   {
00051   arma_extra_debug_sigprint();
00052   
00053   if( (&out != &A_in) && (&out != &B_in) )
00054     {
00055     glue_times::apply_noalias(out,A_in,B_in);
00056     }
00057   else
00058     {
00059     
00060     if( (&out == &A_in) && (&out != &B_in) )
00061       {
00062       Mat<eT> A_copy(A_in);
00063       glue_times::apply_noalias(out,A_copy,B_in);
00064       }
00065     else
00066     if( (&out != &A_in) && (&out == &B_in) )
00067       {
00068       Mat<eT> B_copy(B_in);
00069       glue_times::apply_noalias(out,A_in,B_copy);
00070       }
00071     else
00072     if( (&out == &A_in) && (&out == &B_in) )
00073       {
00074       Mat<eT> tmp(A_in);
00075       glue_times::apply_noalias(out,tmp,tmp);
00076       }
00077 
00078     }
00079     
00080   }
00081 
00082 
00083 template<typename eT>
00084 inline
00085 void
00086 glue_times::apply(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B, const Mat<eT>& C)
00087   {
00088   arma_extra_debug_sigprint();
00089 
00090   arma_debug_assert_mul_size(A, B, "matrix multiply");
00091   arma_debug_assert_mul_size(B, C, "matrix multiply");
00092   
00093   if( mul_storage_cost(A,B) <= mul_storage_cost(B,C) )
00094     {
00095     Mat<eT> tmp;
00096     glue_times::apply_noalias(tmp, A, B);
00097     
00098     if(&out != &C)
00099       {
00100       glue_times::apply_noalias(out, tmp, C);
00101       }
00102     else
00103       {
00104       Mat<eT> C_copy = C;
00105       glue_times::apply_noalias(out, tmp, C_copy);
00106       }
00107       
00108     }
00109   else
00110     {
00111     Mat<eT> tmp;
00112     glue_times::apply_noalias(tmp, B, C);
00113     
00114     if(&out != &A)
00115       {
00116       glue_times::apply_noalias(out, A, tmp);
00117       }
00118     else
00119       {
00120       Mat<eT> A_copy = A;
00121       glue_times::apply_noalias(out, A_copy, tmp);
00122       }
00123     }
00124   
00125   }
00126 
00127 
00128 
//! out = product of a chain of matrices described by the nested Glue expression X.
//!
//! With two matrices, both operands are unwrapped and handed to the two-matrix apply().
//! With three or more, the function works in two phases:
//!   1. assign each matrix a position in 'order', comparing the size of the product
//!      obtained when the first or the last remaining matrix is left out;
//!   2. starting from the matrix with order 0, repeatedly pre- or post-multiply by the
//!      neighbouring matrix with the next lowest order value, alternating between two
//!      temporary matrices and writing the final product into 'out'.
//! Matrices flagged in del[] by get_ptrs() are deleted before returning.
00129 template<typename T1, typename T2>
00130 void
00131 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
00132   {
00133   arma_extra_debug_sigprint();
00134 
00135   typedef typename T1::elem_type eT;
00136 
        // number of matrices in the chain; also used below as an array bound
00137   const s32 N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
00138 
00139   arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
00140 
00141   if(N_mat == 2)
00142     {
          // only two operands: the two-matrix apply() deals with aliasing
00143     const unwrap<T1> tmp1(X.A);
00144     const unwrap<T2> tmp2(X.B);
00145     
00146     glue_times::apply(out, tmp1.M, tmp2.M);
00147     }
00148   else
00149     {
00150     // we have at least three matrices
00151 
          // ptrs[i] points to the i-th matrix of the chain;
          // del[i] tells whether ptrs[i] must be deleted at the end of this function
00152     const Mat<eT>* ptrs[N_mat];
00153     bool            del[N_mat];
00154   
00155     // takes care of any aliasing problems
00156     mat_ptrs_outcheck<glue_times, Glue<T1,T2,glue_times> >::get_ptrs(ptrs, del, X, &out);
00157   
00158     for(s32 i=0; i<N_mat; ++i)  arma_extra_debug_print( arma_boost::format("ptrs[%d] = %x") % i % ptrs[i] );
00159     for(s32 i=0; i<N_mat; ++i)  arma_extra_debug_print( arma_boost::format(" del[%d] = %d") % i %  del[i] );
00160   
00161   
00162     arma_extra_debug_print( arma_boost::format("required size of 'out':  %d, %d") % ptrs[0]->n_rows % ptrs[N_mat-1]->n_cols );
00163   
          // phase 1: decide the multiplication order.
          // order[i] == -1 means matrix i has not been assigned a position yet
00164     int order[N_mat];  for(s32 i=0; i<N_mat; ++i)  order[i] = -1;
00165   
00166     int first_id = 0;
00167     int last_id  = N_mat-1;
00168     int starting_id = -1;
00169   
00170     int mat_count = N_mat;
00171   
          // largest intermediate result seen while ordering; currently only reported via debug output
00172     int largest_size = 0;
00173   
00174     while(mat_count != 0)
00175       {
00176   
            // leftmost matrix that is still unassigned
00177       for(s32 i=first_id; i != N_mat; ++i)
00178         {
00179         if(order[i] == -1)  { first_id = i; break; }
00180         }
00181   
            // rightmost matrix that is still unassigned
00182       for(s32 i=last_id; i != -1; --i)
00183         {
00184         if(order[i] == -1)  { last_id = i; break; }
00185         }
00186   
00187       arma_extra_debug_print();
00188       arma_extra_debug_print(arma_boost::format("mat_count = %d") % mat_count );
00189       arma_extra_debug_print(arma_boost::format("first_id  = %d") % first_id  );
00190       arma_extra_debug_print(arma_boost::format("last_id   = %d") % last_id   );
00191   
            // one unassigned matrix left: it gets order 0 and is where the multiplication starts
00192       if(first_id == last_id)  { order[first_id] = 0; starting_id = first_id; break; }
00193   
            // size of the product of the remaining matrices when the last / the first one is left out
00194       s32 storage_cost_wo_last  = mul_storage_cost( *ptrs[ first_id   ], *ptrs[ last_id-1 ] );
00195       s32 storage_cost_wo_first = mul_storage_cost( *ptrs[ first_id+1 ], *ptrs[ last_id   ] );
00196   
            // the end whose removal gives the smaller product receives the highest free order value;
            // phase 2 consumes lower order values first, so this matrix is multiplied in later
00197       if(storage_cost_wo_last < storage_cost_wo_first)
00198         {
00199         order[last_id]  = mat_count-1;
00200         if(storage_cost_wo_last > largest_size)  largest_size = storage_cost_wo_last;
00201         }
00202       else
00203         {
00204         order[first_id] = mat_count-1;
00205         if(storage_cost_wo_first > largest_size)  largest_size = storage_cost_wo_first;
00206         }
00207   
00208       arma_extra_debug_print(arma_boost::format("storage_cost_wo_last  = %d") % storage_cost_wo_last  );
00209       arma_extra_debug_print(arma_boost::format("storage_cost_wo_first = %d") % storage_cost_wo_first );
00210   
00211       arma_extra_debug_print("order = ");
00212       for(s32 i=0; i != N_mat; ++i)  arma_extra_debug_print(order[i]);
00213   
00214       --mat_count;
00215       }
00216   
00217     arma_extra_debug_print("final order = ");
00218     for(s32 i=0; i != N_mat; ++i)  arma_extra_debug_print(order[i]);
00219   
00220     arma_extra_debug_print(arma_boost::format("*** largest_size = %d") % largest_size);
00221     arma_extra_debug_print(arma_boost::format("starting_id  = %d") % starting_id);
00222   
00223   
00224     // multiply based on order
00225     // if there are only three matrices, we need only one temporary store:
00226     //   out = a*b*c translates to:  tmp1 = a*b,  out = tmp1*c
00227     //
00228     // if there are four matrices, we need two temporary stores
00229     //   out = a*b*c*d translates to:  tmp1 = a*b, tmp2 = tmp1*c, out = tmp2*d
00230     //
00231     // if there are five matrices, we need two temporary stores
00232     //   out = a*b*c*d*e translates to:  tmp1 = a*b, tmp2 = tmp1*c, tmp1 = tmp2*d, out = tmp1*e
00233     //
00234     // if there are six matrices, we need two temporary stores
00235     //   out = a*b*c*d*e*f translates to:  tmp1 = a*b, tmp2 = tmp1*c, tmp1 = tmp2*d, tmp2 = tmp1*e, out = tmp2*f
00236     //
00237   
00238     
          // phase 2: carry out the N_mat-1 multiplications
00239     const u32 N_mul = N_mat - 1;
00240     int mul_count = N_mul;
00241     int current_id = starting_id;
00242   
          // src_mat_1_ptr always points to the running result (initially the starting matrix)
00243     const Mat<eT>* src_mat_1_ptr = ptrs[current_id];
00244     const Mat<eT>* src_mat_2_ptr = 0;
00245   
00246     // TODO:
00247     // allocate two storage areas (of size 'largest_size'), not two matrices
00248   
00249     
00250     Mat<eT> tmp_mat_1;
00251     Mat<eT> tmp_mat_2;
00252     
          // the second temporary is not used when there are at most two multiplications
00253     Mat<eT>* tmp_mat_1_ptr = &tmp_mat_1;
00254     Mat<eT>* tmp_mat_2_ptr = (N_mul <= 2) ? 0 : &tmp_mat_2;
00255   
          // starts at tmp_mat_2_ptr so that the first toggle in the loop selects tmp_mat_1_ptr
00256     Mat<eT>* dest_mat_ptr  = tmp_mat_2_ptr;
00257   
00258     arma_extra_debug_print(arma_boost::format("tmp_mat_1_ptr = %x") % tmp_mat_1_ptr );
00259     arma_extra_debug_print(arma_boost::format("tmp_mat_2_ptr = %x") % tmp_mat_2_ptr );
00260     arma_extra_debug_print(arma_boost::format("&out          = %x") % &out );
00261   
00262     while(mul_count != 0)
00263       {
00264       arma_extra_debug_print("");
00265       arma_extra_debug_print("");
00266       arma_extra_debug_print(arma_boost::format("mul_count = %d") % mul_count);
00267   
00268       arma_extra_debug_print("order = ");
00269       for(s32 i=0; i != N_mat; ++i)  arma_extra_debug_print(order[i]);
00270       arma_extra_debug_print("");
00271   
00272       // only one multiplication left, hence destination matrix is the out matrix
00273       if(mul_count == 1)
00274         {
00275         arma_extra_debug_print("dest_mat = &out");
00276         dest_mat_ptr = &out;
00277         }
00278       else
00279         {
              // alternate between the two temporaries;
              // NOTE: the debug text below names the previous destination, not the new one
00280         if(dest_mat_ptr == tmp_mat_2_ptr)
00281           {
00282           arma_extra_debug_print("dest_mat_ptr = tmp_mat_2_ptr");
00283           dest_mat_ptr = tmp_mat_1_ptr;
00284           }
00285         else
00286           {
00287           arma_extra_debug_print("dest_mat_ptr = tmp_mat_1_ptr");
00288           dest_mat_ptr = tmp_mat_2_ptr;
00289           }
00290         }
00291   
00292       arma_extra_debug_print(arma_boost::format("dest_mat_ptr = %x") % dest_mat_ptr );
00293   
00294       // search on either side of current_pos for a useable value.  unuseable values are equal to -1
            // N_mat serves as the "nothing found" value: it is larger than any assigned order (0 .. N_mat-1)
00295       s32 left_val = N_mat;
00296       s32 left_id = -1;
00297   
00298       s32 right_val = N_mat;
00299       s32 right_id = -1;
00300   
00301       // go left from current_pos
00302       for(s32 i=current_id-1; i >= 0; --i)
00303         if( order[i] > order[current_id] ) { left_val = order[i]; left_id = i; break; }
00304   
00305       // go right from current_pos
00306       for(s32 i=current_id+1; i < N_mat; ++i)
00307         if( order[current_id] < order[i] ) { right_val = order[i]; right_id = i; break; }
00308   
00309       arma_extra_debug_print("");
00310       arma_extra_debug_print(arma_boost::format("left_id  = %d") % left_id  );
00311       arma_extra_debug_print(arma_boost::format("left_val = %f") % left_val );
00312   
00313       arma_extra_debug_print("");
00314       arma_extra_debug_print(arma_boost::format("right_id  = %d") % right_id  );
00315       arma_extra_debug_print(arma_boost::format("right_val = %f") % right_val );
00316   
00317   
            // the neighbour with the lower order value is multiplied in next
00318       if(left_val < right_val)
00319         {
00320         // a pre-multiply
00321         src_mat_2_ptr = ptrs[left_id];
00322   
00323         arma_extra_debug_print("");
00324         arma_extra_debug_print(arma_boost::format("case pre-multiply with matrix %d") % left_id);
00325         arma_extra_debug_print(arma_boost::format("required destination size: %d, %d  (%d)") %  src_mat_2_ptr->n_rows % src_mat_1_ptr->n_cols % (src_mat_2_ptr->n_rows * src_mat_1_ptr->n_cols) );
00326   
00327         glue_times::apply_noalias(*dest_mat_ptr, *src_mat_2_ptr, *src_mat_1_ptr);
00328   
              // mark the old position as consumed and continue from the matrix just multiplied in
00329         order[current_id] = -1;
00330         current_id = left_id;
00331         }
00332       else
00333         {
00334         // a post-multiply
00335         src_mat_2_ptr = ptrs[right_id];
00336   
00337         arma_extra_debug_print("");
00338         arma_extra_debug_print(arma_boost::format("case post-multiply with matrix %d") % right_id);
00339         arma_extra_debug_print(arma_boost::format("required destination size: %d, %d  (%d)") % src_mat_1_ptr->n_rows % src_mat_2_ptr->n_cols % (src_mat_1_ptr->n_rows * src_mat_2_ptr->n_cols) );
00340   
00341         glue_times::apply_noalias(*dest_mat_ptr, *src_mat_1_ptr, *src_mat_2_ptr);
00342   
              // mark the old position as consumed and continue from the matrix just multiplied in
00343         order[current_id] = -1;
00344         current_id = right_id;
00345         }
00346   
00347       // update pointer to source matrix: must point to last multiplication result
00348       src_mat_1_ptr = dest_mat_ptr;
00349   
00350       --mul_count;
00351       }
00352   
00353   
          // delete the matrices that get_ptrs() flagged in del[]
00354     for(s32 i=0; i<N_mat; ++i)
00355       {
00356       if(del[i] == true)
00357         {
00358         arma_extra_debug_print(arma_boost::format("delete mat_ptr[%d]") % i );
00359         delete ptrs[i];
00360         }
00361       }
00362     }
00363   }
00364 
00365 
00366 
00367 template<typename eT>
00368 inline
00369 void
00370 glue_times::apply(Mat<eT>& out, const Glue<Mat<eT>,Mat<eT>,glue_times>& X)
00371   {
00372   glue_times::apply(out, X.A, X.B);
00373   }
00374 
00375 
00376 
00377 template<typename eT>
00378 inline
00379 void
00380 glue_times::apply(Mat<eT>& out, const Glue< Glue<Mat<eT>,Mat<eT>, glue_times>, Mat<eT>, glue_times>& X)
00381   {
00382   glue_times::apply(out, X.A.A, X.A.B, X.B);
00383   }
00384 
00385 
00386 
00387 template<typename eT>
00388 inline
00389 void
00390 glue_times::apply_inplace(Mat<eT>& out, const Mat<eT>& B)
00391   {
00392   arma_extra_debug_sigprint();
00393   
00394   arma_debug_assert_mul_size(out, B, "matrix multiply");
00395   
00396   if(out.n_cols == B.n_cols)
00397     {
00398     podarray<eT> tmp(out.n_cols);
00399     eT* tmp_rowdata = tmp.memptr();
00400     
00401     for(u32 out_row=0; out_row < out.n_rows; ++out_row)
00402       {
00403       for(u32 out_col=0; out_col < out.n_cols; ++out_col)
00404         {
00405         tmp_rowdata[out_col] = out.at(out_row,out_col);
00406         }
00407       
00408       for(u32 B_col=0; B_col < B.n_cols; ++B_col)
00409         {
00410         const eT* B_coldata = B.colptr(B_col);
00411         
00412         eT val = eT(0);
00413         for(u32 i=0; i < B.n_rows; ++i)
00414           {
00415           val += tmp_rowdata[i] * B_coldata[i];
00416           }
00417         
00418         out.at(out_row,B_col) = val;
00419         }
00420       }
00421     
00422     }
00423   else
00424     {
00425     Mat<eT> tmp = out;
00426     glue_times::apply(out, tmp, B);
00427     }
00428   
00429   }
00430 
00431 
00432 
00433 template<typename T1, typename op_type>
00434 inline
00435 void
00436 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const Op<T1, op_type>& X)
00437   {
00438   arma_extra_debug_sigprint();
00439   
00440   typedef typename T1::elem_type eT;
00441   
00442   const Mat<eT> tmp(X);
00443   glue_times::apply(out, out, tmp);
00444   }
00445 
00446 
00447 
00448 template<typename T1, typename T2, typename glue_type>
00449 inline
00450 void
00451 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_type>& X)
00452   {
00453   arma_extra_debug_sigprint();
00454   
00455   out = out * X;
00456   }
00457 
00458 
00459 
00460 //! out = T1 * trans(T2)
00461 template<typename T1, typename T2>
00462 inline
00463 void
00464 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1, Op<T2,op_trans>, glue_times>& X)
00465   {
00466   arma_extra_debug_sigprint();
00467   
00468   
00469   typedef typename T1::elem_type eT;
00470   
00471   // checks for aliases are done later
00472   
00473   const unwrap<T1> tmp1(X.A);
00474   const unwrap<T2> tmp2(X.B.m);
00475   
00476   const Mat<eT>& A = tmp1.M;
00477   const Mat<eT>& B = tmp2.M;
00478   
00479   arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_cols, B.n_rows, "matrix multiply");
00480     
00481   if( (A.n_rows*B.n_rows) > 0)
00482     {
00483     if(&A != &B)   // A*B'
00484       {
00485       unwrap_check< Mat<eT> > A_safe_tmp(A, out);
00486       unwrap_check< Mat<eT> > B_safe_tmp(B, out);
00487       
00488       const Mat<eT>& A_safe = A_safe_tmp.M;
00489       const Mat<eT>& B_safe = B_safe_tmp.M;
00490       
00491       out.set_size(A_safe.n_rows, B_safe.n_rows);
00492       
00493       gemm<false,true>::apply(out, A, B);
00494       }
00495     else   // A*A'
00496       {
00497       arma_extra_debug_print("glue_times::apply(): detected A*A'");
00498       
00499       Mat<eT> tmp;
00500       op_trans::apply(tmp,A);
00501       
00502       // no aliasing problem
00503       out.set_size(A.n_rows, A.n_rows);
00504       
00505       for(u32 row=0; row != A.n_rows; ++row)
00506         {
00507         for(u32 col=0; col <= row; ++col)
00508           {
00509           const eT* coldata1 = tmp.colptr(row);
00510           const eT* coldata2 = tmp.colptr(col);
00511         
00512           eT val = eT(0);
00513           for(u32 i=0; i < tmp.n_rows; ++i)
00514             {
00515             val += coldata1[i] * coldata2[i];
00516             }
00517           
00518           out.at(row,col) = val;
00519           out.at(col,row) = val;
00520           }
00521         }
00522         
00523       }
00524     
00525     }
00526   
00527   }
00528 
00529 
00530 
00531 //! out = trans(T1) * T2
00532 template<typename T1, typename T2>
00533 inline
00534 void
00535 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue< Op<T1,op_trans>, T2, glue_times>& X)
00536   {
00537   arma_extra_debug_sigprint();
00538   
00539   typedef typename T1::elem_type eT;
00540   
00541   const unwrap_check<T1> tmp1(X.A.m, out);
00542   const unwrap_check<T2> tmp2(X.B,   out);
00543   
00544   const Mat<eT>& A = tmp1.M;
00545   const Mat<eT>& B = tmp2.M;
00546   
00547   arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "matrix multiply");
00548   
00549   if( (A.n_cols*B.n_cols) > 0 )
00550     {
00551     out.set_size(A.n_cols, B.n_cols);
00552     
00553     gemm<true,false>::apply(out, A, B);
00554     }
00555     
00556   }
00557 
00558 
00559 
00560 //! out = trans(T1) * trans(T2)
00561 template<typename T1, typename T2>
00562 inline
00563 void
00564 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue< Op<T1,op_trans>, Op<T2,op_trans>, glue_times>& X)
00565   {
00566   arma_extra_debug_sigprint();
00567   
00568   typedef typename T1::elem_type eT;
00569   
00570   const unwrap_check<T1> tmp1(X.A.m, out);
00571   const unwrap_check<T2> tmp2(X.B.m, out);
00572   
00573   const Mat<eT>& A = tmp1.M;
00574   const Mat<eT>& B = tmp2.M;
00575   
00576   arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_cols, B.n_rows, "matrix multiply");
00577   
00578   if( (A.n_cols*B.n_rows) > 0 )
00579     {
00580     out.set_size(A.n_cols, B.n_rows);
00581     
00582     gemm<true,true>::apply(out, A, B);
00583     
00584     }
00585     
00586   }
00587 
00588 
00589 
00590 
00591 //! out = -T1 * T2
00592 template<typename T1, typename T2>
00593 inline
00594 void
00595 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue< Op<T1, op_neg>, T2, glue_times>& X)
00596   {
00597   arma_extra_debug_sigprint();
00598   
00599   typedef typename T1::elem_type eT;
00600   
00601   const unwrap_check<T1> tmp1(X.A.m, out);
00602   const unwrap_check<T2> tmp2(X.B,   out);
00603   
00604   const Mat<eT>& A = tmp1.M;
00605   const Mat<eT>& B = tmp2.M;
00606 
00607   glue_times::apply(out, A, B);
00608   
00609   const u32 n_elem = out.n_elem;
00610   for(u32 i=0; i<n_elem; ++i)
00611     {
00612     out[i] = -out[i];
00613     }
00614   }
00615 
00616 
00617 
00618 
00619 template<typename eT>
00620 inline
00621 eT
00622 glue_times::direct_rowvec_mat_colvec
00623   (
00624   const eT*            A_mem,
00625   const Mat<eT>& B,
00626   const eT*            C_mem
00627   )
00628   {
00629   arma_extra_debug_sigprint();
00630   
00631   const u32 cost_AB = B.n_cols;
00632   const u32 cost_BC = B.n_rows;
00633   
00634   if(cost_AB <= cost_BC)
00635     {
00636     podarray<eT> tmp(B.n_cols);
00637     
00638     for(u32 col=0; col<B.n_cols; ++col)
00639       {
00640       const eT* B_coldata = B.colptr(col);
00641       
00642       eT val = eT(0);
00643       for(u32 i=0; i<B.n_rows; ++i)
00644         {
00645         val += A_mem[i] * B_coldata[i];
00646         }
00647         
00648       tmp[col] = val;
00649       }
00650     
00651     return op_dot::direct_dot(B.n_cols, tmp.mem, C_mem);
00652     }
00653   else
00654     {
00655     podarray<eT> tmp(B.n_rows);
00656     
00657     for(u32 row=0; row<B.n_rows; ++row)
00658       {
00659       eT val = eT(0);
00660       for(u32 col=0; col<B.n_cols; ++col)
00661         {
00662         val += B.at(row,col) * C_mem[col];
00663         }
00664       
00665       tmp[row] = val;
00666       }
00667     
00668     return op_dot::direct_dot(B.n_rows, A_mem, tmp.mem);
00669     }
00670   
00671   
00672   }
00673 
00674 
00675 
00676 template<typename eT>
00677 inline
00678 eT
00679 glue_times::direct_rowvec_diagmat_colvec
00680   (
00681   const eT*            A_mem,
00682   const Mat<eT>& B,
00683   const eT*            C_mem
00684   )
00685   {
00686   arma_extra_debug_sigprint();
00687   
00688   eT val = eT(0);
00689 
00690   for(u32 i=0; i<B.n_rows; ++i)
00691     {
00692     val += A_mem[i] * B.at(i,i) * C_mem[i];
00693     }
00694 
00695   return val;
00696   }
00697 
00698 
00699 
00700 template<typename eT>
00701 inline
00702 eT
00703 glue_times::direct_rowvec_invdiagmat_colvec
00704   (
00705   const eT*            A_mem,
00706   const Mat<eT>& B,
00707   const eT*            C_mem
00708   )
00709   {
00710   arma_extra_debug_sigprint();
00711   
00712   eT val = eT(0);
00713 
00714   for(u32 i=0; i<B.n_rows; ++i)
00715     {
00716     val += (A_mem[i] * C_mem[i]) / B.at(i,i);
00717     }
00718 
00719   return val;
00720   }
00721 
00722 
00723 
00724 template<typename eT>
00725 inline
00726 eT
00727 glue_times::direct_rowvec_invdiagvec_colvec
00728   (
00729   const eT*            A_mem,
00730   const Mat<eT>& B,
00731   const eT*            C_mem
00732   )
00733   {
00734   arma_extra_debug_sigprint();
00735   
00736   const eT* B_mem = B.mem;
00737   
00738   eT val = eT(0);
00739 
00740   for(u32 i=0; i<B.n_elem; ++i)
00741     {
00742     val += (A_mem[i] * C_mem[i]) / B_mem[i];
00743     }
00744 
00745   return val;
00746   }
00747 
00748 
00749 
00750 //
00751 // matrix multiplication with different element types
00752 
00753 template<typename eT1, typename eT2>
00754 inline
00755 void
00756 glue_times::apply_mixed(Mat<typename promote_type<eT1,eT2>::result>& out, const Mat<eT1>& X, const Mat<eT2>& Y)
00757   {
00758   arma_extra_debug_sigprint();
00759   
00760   typedef typename promote_type<eT1,eT2>::result out_eT;
00761   
00762   arma_debug_assert_mul_size(X,Y, "matrix multiply");
00763   
00764   out.set_size(X.n_rows,Y.n_cols);
00765   gemm_mixed<>::apply(out, X, Y);
00766   }
00767 
00768 
00769 
00770 //
00771 // glue_times_diag
00772 
00773 
00774 template<typename T1, typename T2>
00775 inline
00776 void
00777 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const T1& A_orig, const Op<T2,op_diagmat>& B_orig)
00778   {
00779   arma_extra_debug_sigprint();
00780     
00781   isnt_same_type<typename T1::elem_type, typename T2::elem_type>::check();
00782   
00783   const unwrap_check<T1> tmp1(A_orig,   out);
00784   const unwrap_check<T2> tmp2(B_orig.m, out);
00785   
00786   typedef typename T1::elem_type eT;
00787   
00788   const Mat<eT>& A = tmp1.M;
00789   const Mat<eT>& B = tmp2.M;
00790   
00791   arma_debug_check( (B.is_square() == false), "glue_times_diag::apply(): incompatible matrix dimensions" );
00792   arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiply");
00793   
00794   out.set_size(A.n_rows, B.n_cols);
00795   
00796   for(u32 col=0; col<A.n_cols; ++col)
00797     {
00798     const eT val = B.at(col,col);
00799     
00800     const eT*   A_coldata =   A.colptr(col);
00801           eT* out_coldata = out.colptr(col);
00802     
00803     for(u32 row=0; row<B.n_rows; ++row)
00804       {
00805       out_coldata[row] = A_coldata[row] * val;
00806       }
00807     
00808     }
00809   
00810   }
00811 
00812 
00813 
00814 template<typename T1, typename T2>
00815 inline
00816 void
00817 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_diagmat>& A_orig, const T2& B_orig)
00818   {
00819   arma_extra_debug_sigprint();
00820   
00821   isnt_same_type<typename T1::elem_type, typename T2::elem_type>::check();
00822   
00823   const unwrap_check<T1> tmp1(A_orig.m, out);
00824   const unwrap_check<T2> tmp2(B_orig,   out);
00825   
00826   typedef typename T1::elem_type eT;
00827   
00828   const Mat<eT>& A = tmp1.M;
00829   const Mat<eT>& B = tmp2.M;
00830   
00831   arma_debug_check( (A.is_square() == false), "glue_times_diag::apply(): incompatible matrix dimensions" );
00832   arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiply");
00833   
00834   out.set_size(A.n_rows, B.n_cols);
00835   
00836   
00837   for(u32 col=0; col<A.n_cols; ++col)
00838     {
00839     const eT*   B_coldata =   B.colptr(col);
00840           eT* out_coldata = out.colptr(col);
00841     
00842     for(u32 row=0; row<B.n_rows; ++row)
00843       {
00844       out_coldata[row] = A.at(row,row) * B_coldata[row];
00845       }
00846     
00847     }
00848 
00849   }
00850 
00851 
00852 
00853 template<typename T1, typename T2>
00854 inline
00855 void
00856 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_diagmat>& A_orig, const Op<T2,op_diagmat>& B_orig)
00857   {
00858   arma_extra_debug_sigprint();
00859   
00860   isnt_same_type<typename T1::elem_type, typename T2::elem_type>::check();
00861   
00862   unwrap_check<T1> tmp1(A_orig.m, out);
00863   unwrap_check<T2> tmp2(B_orig.m, out);
00864   
00865   typedef typename T1::elem_type eT;
00866   
00867   const Mat<eT>& A = tmp1.M;
00868   const Mat<eT>& B = tmp2.M;
00869   
00870   arma_debug_check( !A.is_square() || !B.is_square(), "glue_times_diag::apply(): incompatible matrix dimensions" );
00871   arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiply");
00872   
00873   out.zeros(A.n_rows, B.n_cols);
00874   
00875   for(u32 i=0; i<A.n_rows; ++i)
00876     {
00877     out.at(i,i) = A.at(i,i) * B.at(i,i);
00878     }
00879   }
00880 
00881 
00882 
00883 template<typename T1, typename T2>
00884 inline
00885 void 
00886 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, Op<T2,op_diagmat>, glue_times_diag>& X)
00887   {
00888   glue_times_diag::apply(out, X.A, X.B);
00889   }
00890 
00891 
00892 
00893 template<typename T1, typename T2>
00894 inline
00895 void
00896 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<Op<T1,op_diagmat>, T2, glue_times_diag>& X)
00897   {
00898   glue_times_diag::apply(out, X.A, X.B);
00899   }
00900 
00901 
00902 
00903 template<typename T1, typename T2>
00904 inline
00905 void
00906 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<Op<T1,op_diagmat>, Op<T2,op_diagmat>, glue_times_diag>& X)
00907   {
00908   glue_times_diag::apply(out, X.A, X.B);
00909   }
00910 
00911 
00912 
00913 //
00914 // glue_times_vec
00915 
00916 
00917 
00918 template<typename eT>
00919 inline
00920 void
00921 glue_times_vec::mul_col_row(Mat<eT>& out, const eT* A, const eT* B)
00922   {
00923   const u32 n_rows = out.n_rows;
00924   const u32 n_cols = out.n_cols;
00925   
00926   for(u32 col=0; col < n_cols; ++col)
00927     {
00928     const eT val = B[col];
00929     
00930     eT* out_coldata = out.colptr(col);
00931     
00932     for(u32 row=0; row < n_rows; ++row)
00933       {
00934       out_coldata[row] = A[row] * val;
00935       }
00936     }
00937   
00938   }
00939 
00940 
00941 
00942 template<typename T1>
00943 inline
00944 void
00945 glue_times_vec::apply(Mat<typename T1::elem_type>& out, const Glue<T1, Col<typename T1::elem_type>,glue_times_vec>& X)
00946   {
00947   arma_extra_debug_sigprint();
00948   
00949   typedef typename T1::elem_type eT;
00950   
00951   unwrap_check< T1 >               tmp1(X.A, out);
00952   unwrap_check< Col<eT> > tmp2(X.B, out);
00953   
00954   const Mat<eT>& A = tmp1.M;
00955   const Col<eT>& B = tmp2.M;
00956 
00957   arma_debug_assert_mul_size(A, B, "vector multiply");
00958   
00959   out.set_size(A.n_rows, 1);
00960   
00961   //gemm<>::apply(out,A,B);  // NOTE: B is interpreted as a Mat
00962   gemv<>::apply(out.memptr(), A, B.mem);
00963   }
00964 
00965 
00966 
00967 template<typename T1>
00968 inline
00969 void
00970 glue_times_vec::apply(Mat<typename T1::elem_type>& out, const Glue<T1, Row<typename T1::elem_type>,glue_times_vec>& X)
00971   {
00972   arma_extra_debug_sigprint();
00973   
00974   // T1 * rowvec makes sense only if T1 ends up being a matrix with one column (i.e. a column vector)
00975   
00976   typedef typename T1::elem_type eT;
00977   
00978   unwrap_check< T1 >               tmp1(X.A, out);
00979   unwrap_check< Row<eT> > tmp2(X.B, out);
00980   
00981   const Mat<eT>& A = tmp1.M;
00982   const Mat<eT>& B = tmp2.M;   // NOTE: interpretation of a Row as a Mat
00983   
00984   arma_debug_assert_mul_size(A, B, "vector multiply");
00985   
00986   out.set_size(A.n_rows, B.n_cols);
00987   
00988   glue_times_vec::mul_col_row(out, A.mem, B.mem);
00989   }
00990 
00991 
00992 
00993 template<typename T1>
00994 inline
00995 void
00996 glue_times_vec::apply(Mat<typename T1::elem_type>& out, const Glue<Col<typename T1::elem_type>,T1,glue_times_vec>& X)
00997   {
00998   arma_extra_debug_sigprint();
00999   
01000   // colvec * T1 makes sense only if T1 ends up being a matrix with one row (i.e. a row vector)
01001   
01002   
01003   typedef typename T1::elem_type eT;
01004   
01005   unwrap_check< Col<eT> > tmp1(X.A, out);
01006   unwrap_check< T1 >               tmp2(X.B, out);
01007   
01008   const Mat<eT>& A = tmp1.M;   // NOTE: interpretation of a Col as a Mat
01009   const Mat<eT>& B = tmp2.M;
01010   
01011   arma_debug_assert_mul_size(A, B, "vector multiply");
01012   
01013   out.set_size(A.n_rows, B.n_cols);
01014   
01015   glue_times_vec::mul_col_row(out, A.mem, B.mem);
01016   }
01017 
01018 
01019 
01020 template<typename T1>
01021 inline
01022 void
01023 glue_times_vec::apply(Mat<typename T1::elem_type>& out, const Glue<Row<typename T1::elem_type>,T1,glue_times_vec>& X)
01024   {
01025   arma_extra_debug_sigprint();
01026   
01027   typedef typename T1::elem_type eT;
01028   
01029   unwrap_check< Row<eT> > tmp1(X.A, out);
01030   unwrap_check< T1 >               tmp2(X.B, out);
01031   
01032   const Row<eT>& A = tmp1.M;
01033   const Mat<eT>& B = tmp2.M;
01034   
01035   arma_debug_assert_mul_size(A, B, "vector multiply");
01036   
01037   out.set_size(A.n_rows, B.n_cols);
01038   
01039 //         eT* out_mem = out.memptr();
01040 //   const eT* A_mem   = A.mem;
01041 //   
01042 //   const u32 A_n_cols = A.n_cols;
01043 //   const u32 B_n_cols = B.n_cols;
01044 //   
01045 //   for(u32 col=0; col<B_n_cols; ++col)
01046 //     {
01047 //     const eT* B_coldata = B.colptr(col);
01048 //     
01049 //     eT val = eT(0);
01050 //     for(u32 i=0; i<A_n_cols; ++i)
01051 //       {
01052 //       val += A_mem[i] * B_coldata[i];
01053 //       }
01054 //       
01055 //     out_mem[col] = val;
01056 //     }
01057   
01058   gemv<true>::apply(out.memptr(), B, A.mem);
01059   }
01060 
01061 
01062 
01063 template<typename eT>
01064 inline
01065 void
01066 glue_times_vec::apply(Mat<eT>& out, const Glue<Col<eT>,Row<eT>,glue_times_vec>& X)
01067   {
01068   arma_extra_debug_sigprint();
01069   
01070   unwrap_check< Col<eT> > tmp1(X.A, out);
01071   unwrap_check< Row<eT> > tmp2(X.B, out);
01072   
01073   const Col<eT>& A = tmp1.M;
01074   const Row<eT>& B = tmp2.M;
01075   
01076   arma_debug_assert_mul_size(A, B, "vector multiply");
01077   
01078   out.set_size(A.n_rows, B.n_cols);
01079   
01080   glue_times_vec::mul_col_row(out, A.mem, B.mem);
01081   }
01082 
01083 
01084 
01085 template<typename eT>
01086 inline
01087 void
01088 glue_times_vec::apply(Mat<eT>& out, const Glue< Op<Row<eT>, op_trans>, Row<eT>, glue_times_vec>& X)
01089   {
01090   arma_extra_debug_sigprint();
01091   
01092   unwrap_check< Row<eT> > tmp1(X.A.m, out);
01093   unwrap_check< Row<eT> > tmp2(X.B,   out);
01094   
01095   const Row<eT>& A = tmp1.M;
01096   const Row<eT>& B = tmp2.M;
01097   
01098   arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "vector multiply");
01099   
01100   out.set_size(A.n_cols, B.n_cols);
01101   
01102   glue_times_vec::mul_col_row(out, A.mem, B.mem);
01103   }
01104 
01105 
01106 
01107 template<typename eT>
01108 inline
01109 void
01110 glue_times_vec::apply(Mat<eT>& out, const Glue< Col<eT>, Op<Col<eT>, op_trans>, glue_times_vec>& X)
01111   {
01112   arma_extra_debug_sigprint();
01113   
01114   unwrap_check< Col<eT> > tmp1(X.A,   out);
01115   unwrap_check< Col<eT> > tmp2(X.B.m, out);
01116   
01117   const Col<eT>& A = tmp1.M;
01118   const Col<eT>& B = tmp2.M;
01119   
01120   arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_cols, B.n_rows, "vector multiply");
01121   
01122   out.set_size(A.n_rows, B.n_rows);
01123   
01124   glue_times_vec::mul_col_row(out, A.mem, B.mem);
01125   }
01126 
01127 
01128 
01129 template<typename T1>
01130 inline
01131 void
01132 glue_times_vec::apply(Mat<typename T1::elem_type>& out, const Glue<Op<T1, op_trans>, Col<typename T1::elem_type>,glue_times_vec>& X)
01133   {
01134   arma_extra_debug_sigprint();
01135   
01136   typedef typename T1::elem_type eT;
01137   
01138   unwrap_check< T1 >               tmp1(X.A.m, out);
01139   unwrap_check< Col<eT> > tmp2(X.B,   out);
01140   
01141   const Mat<eT>& A = tmp1.M;
01142   const Col<eT>& B = tmp2.M;
01143 
01144   arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "vector multiply");
01145   
01146   out.set_size(A.n_cols, B.n_cols);
01147   
01148 //         eT* out_mem = out.memptr();
01149 //   const eT* B_mem   = B.mem;
01150 //   
01151 //   const u32 A_n_cols = A.n_cols;
01152 //   const u32 B_n_rows = B.n_rows;
01153 //   
01154 //   for(u32 col=0; col < A_n_cols; ++col)
01155 //     {
01156 //     const eT* A_col = A.colptr(col);
01157 //     
01158 //     eT val = eT(0);
01159 //     for(u32 row=0; row<B_n_rows; ++row)
01160 //       {
01161 //       val += A_col[row] * B_mem[row];
01162 //       }
01163 //     
01164 //     out_mem[col] = val;
01165 //     }
01166   
01167   gemv<true>::apply(out.memptr(), A, B.mem);
01168   }
01169 
01170 
01171 
01172 //! @}