op_trans_meat.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2009 NICTA
00002 // 
00003 // Authors:
00004 // - Conrad Sanderson (conradsand at ieee dot org)
00005 // 
00006 // This file is part of the Armadillo C++ library.
00007 // It is provided without any warranty of fitness
00008 // for any purpose. You can redistribute this file
00009 // and/or modify it under the terms of the GNU
00010 // Lesser General Public License (LGPL) as published
00011 // by the Free Software Foundation, either version 3
00012 // of the License or (at your option) any later version.
00013 // (see http://www.opensource.org/licenses for more info)
00014 
00015 
00016 //! \addtogroup op_trans
00017 //! @{
00018 
00019 
00020 //! Immediate transpose of a dense matrix
00021 template<typename eT>
00022 inline
00023 void
00024 op_trans::apply_noalias(Mat<eT>& out, const Mat<eT>& A)
00025   {
00026   arma_extra_debug_sigprint();
00027   
00028   const u32 A_n_cols = A.n_cols;
00029   const u32 A_n_rows = A.n_rows;
00030   
00031   
00032   out.set_size(A_n_cols, A_n_rows);
00033   
00034   if( (A_n_cols == 1) || (A_n_rows == 1) )
00035     {
00036     syslib::copy_elem( out.memptr(), A.mem, A.n_elem );
00037     }
00038   else
00039     {
00040     for(u32 in_row = 0; in_row<A_n_rows; ++in_row)
00041       {
00042       const u32 out_col = in_row;
00043     
00044       for(u32 in_col = 0; in_col<A_n_cols; ++in_col)
00045         {
00046         const u32 out_row = in_col;
00047         out.at(out_row, out_col) = A.at(in_row, in_col);
00048         }
00049       }
00050     }
00051   
00052   }
00053 
00054 
00055 
00056 template<typename eT>
00057 inline
00058 void
00059 op_trans::apply(Mat<eT>& out, const Mat<eT>& A)
00060   {
00061   arma_extra_debug_sigprint();
00062   
00063   if(&out != &A)
00064     {
00065     op_trans::apply_noalias(out, A);
00066     }
00067   else
00068     {
00069     if(out.n_rows == out.n_cols)
00070       {
00071       arma_extra_debug_print("doing in-place transpose of a square matrix");
00072       
00073       const u32 n_rows = out.n_rows;
00074       const u32 n_cols = out.n_cols;
00075       
00076       for(u32 col=0; col<n_cols; ++col)
00077         {
00078         eT* coldata = out.colptr(col);
00079         
00080         for(u32 row=(col+1); row<n_rows; ++row)
00081           {
00082           std::swap( out.at(col,row), coldata[row] );
00083           }
00084         }
00085       }
00086     else
00087       {
00088       const Mat<eT> A_copy = A;
00089       op_trans::apply_noalias(out, A_copy);
00090       }
00091     }
00092   }
00093 
00094 
00095 
00096 template<typename T1>
00097 inline
00098 void
00099 op_trans::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_trans>& in)
00100   {
00101   arma_extra_debug_sigprint();
00102   
00103   typedef typename T1::elem_type eT;
00104   
00105   const unwrap<T1> tmp(in.m);
00106   const Mat<eT>& A = tmp.M;
00107   
00108   if(&out != &A)
00109     {
00110     op_trans::apply_noalias(out, A);
00111     }
00112   else
00113     {
00114     if(out.n_rows == out.n_cols)
00115       {
00116       arma_extra_debug_print("doing in-place transpose of a square matrix");
00117       
00118       const u32 n_rows = out.n_rows;
00119       const u32 n_cols = out.n_cols;
00120       
00121       for(u32 col=0; col<n_cols; ++col)
00122         {
00123         eT* coldata = out.colptr(col);
00124         
00125         for(u32 row=(col+1); row<n_rows; ++row)
00126           {
00127           std::swap( out.at(col,row), coldata[row] );
00128           }
00129         }
00130       }
00131     else
00132       {
00133       const Mat<eT> A_copy = A;
00134       op_trans::apply_noalias(out, A_copy);
00135       }
00136     }
00137   
00138   }
00139 
00140 
00141 
00142 template<typename T1, typename T2>
00143 inline
00144 void
00145 op_trans::apply(Mat<typename T1::elem_type>& out, const Op< Glue<T1,T2,glue_plus>, op_trans>& in)
00146   {
00147   arma_extra_debug_sigprint();
00148   
00149   typedef typename T1::elem_type eT;
00150   
00151   isnt_same_type<eT, typename T2::elem_type>::check();
00152   
00153   const unwrap_check<T1> tmp1(in.m.A, out);
00154   const unwrap_check<T2> tmp2(in.m.B, out);
00155   
00156   const Mat<eT>& A = tmp1.M;
00157   const Mat<eT>& B = tmp2.M;
00158   
00159   arma_debug_assert_same_size(A, B, "matrix addition");
00160   
00161   out.set_size(A.n_cols, A.n_rows);
00162   
00163   if( ( (A.n_rows == 1) || (A.n_cols == 1) ) && ( (B.n_rows == 1) || (B.n_cols == 1) ) )
00164     {
00165     for(u32 i=0; i<A.n_elem; ++i)
00166       {
00167       out[i] = A[i] + B[i];
00168       }
00169     }
00170   else
00171     {
00172     const u32 A_n_cols = A.n_cols;
00173     const u32 A_n_rows = A.n_rows;
00174   
00175     for(u32 col=0; col<A_n_cols; ++col)
00176       {
00177       const u32 out_row = col;
00178       for(u32 row=0; row<A_n_rows; ++row)
00179         {
00180         const u32 out_col = row;
00181         out.at(out_row, out_col) = A.at(row,col) + B.at(row,col);
00182         }
00183       
00184       }
00185     
00186     }
00187   
00188   }
00189 
00190 
00191 
00192 // inline void op_trans::apply_inplace(mat &X)
00193 //   {
00194 //   arma_extra_debug_sigprint();
00195 //   
00196 //   if((X.n_rows == 1) || (X.n_cols == 1))
00197 //     {
00198 //     const u32 old_n_rows = X.n_rows;
00199 //     access::rw(X.n_rows) = X.n_cols;
00200 //     access::rw(X.n_cols) = old_n_rows;
00201 //     }
00202 //   else
00203 //   if(X.n_rows == X.n_cols)
00204 //     {
00205 //     for(u32 col=0; col < X.n_cols; ++col)
00206 //       {
00207 //       double* X_coldata = X.colptr(col);
00208 //       
00209 //       for(u32 row=(col+1); row < X.n_rows; ++row)
00210 //         {
00211 //         std::swap( A.at(col,row), A_coldata[row] );
00212 //         }
00213 //       }
00214 //     }
00215 //   else
00216 //     {
00217 //     mat tmp = trans(X);
00218 //     
00219 //     if(X.mem != X.mem_local)
00220 //       {
00221 //       double* old_mem = X.memptr();
00222 //       access::rw(X.mem) = tmp.memptr();
00223 //       access::rw(tmp.mem) = old_mem;
00224 //       }
00225 //     else
00226 //       {
00227 //       X = tmp;
00228 //       }
00229 //     }
00230 //   
00231 //   }
00232 
00233 //! @}