op_sort_meat.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2009 NICTA
00002 // 
00003 // Authors:
00004 // - Conrad Sanderson (conradsand at ieee dot org)
00005 // 
00006 // This file is part of the Armadillo C++ library.
00007 // It is provided without any warranty of fitness
00008 // for any purpose. You can redistribute this file
00009 // and/or modify it under the terms of the GNU
00010 // Lesser General Public License (LGPL) as published
00011 // by the Free Software Foundation, either version 3
00012 // of the License or (at your option) any later version.
00013 // (see http://www.opensource.org/licenses for more info)
00014 
00015 
00016 //! \addtogroup op_sort
00017 //! @{
00018 
00019 
// qsort() is used rather than std::sort() for now.
// std::sort() will be used once a Random Access Iterator wrapper for plain
// arrays is ready; until then, std::sort() would require copying elements
// to and from a std::vector.
//
// qsort() requires its comparison callback to impose a consistent total order.
// Otherwise the result is undefined, and some C library implementations can
// even read or write outside the array. A plain "<" / ">" comparison is NOT a
// total order once NaN values are present, because every comparison involving
// NaN is false. The callbacks below therefore treat NaN as greater than every
// other value, and equivalent to another NaN. As a result, NaNs are placed last
// by an ascending sort and first by a descending sort.

//! Comparison callbacks for std::qsort() on arrays of eT.
template<typename eT>
class arma_qsort_helper
  {
  public:
  
  //! qsort() comparator for ascending order.
  //! Returns -1 if *A_orig sorts before *B_orig, +1 if after, 0 if equivalent.
  static
  int
  ascend_compare(const void* A_orig, const void* B_orig)
    {
    const eT& A = *(static_cast<const eT*>(A_orig));
    const eT& B = *(static_cast<const eT*>(B_orig));
    
    // NaN is the only value that does not compare equal to itself;
    // for integer element types these are always false
    const bool A_is_nan = (A != A);
    const bool B_is_nan = (B != B);
    
    if(A_is_nan || B_is_nan)
      {
      if(A_is_nan && B_is_nan)
        {
        return 0;
        }
      
      // NaN sorts after every non-NaN value
      return A_is_nan ? +1 : -1;
      }
    
    if(A < B)
      {
      return -1;
      }
    else
    if(A > B)
      {
      return +1;
      }
    else
      {
      return 0;
      }
    }
  
  
  
  //! qsort() comparator for descending order.
  //! Defined as the exact reverse of ascend_compare(), so it is also a
  //! consistent total order.
  static
  int
  descend_compare(const void* A_orig, const void* B_orig)
    {
    return ascend_compare(B_orig, A_orig);
    }
  
  
  };
00077 
00078 
00079 
00080 //template<>
00081 template<typename T>
00082 class arma_qsort_helper< std::complex<T> >
00083   {
00084   public:
00085   
00086   typedef typename std::complex<T> eT;
00087   
00088   
00089   static
00090   int
00091   ascend_compare(const void* A_orig, const void* B_orig)
00092     {
00093     const eT& A = *(static_cast<const eT*>(A_orig));
00094     const eT& B = *(static_cast<const eT*>(B_orig));
00095     
00096     const T abs_A = std::abs(A);
00097     const T abs_B = std::abs(B);
00098     
00099     if(abs_A < abs_B)
00100       {
00101       return -1;
00102       }
00103     else
00104     if(abs_A > abs_B)
00105       {
00106       return +1;
00107       }
00108     else
00109       {
00110       return 0;
00111       }
00112     }
00113   
00114   
00115   
00116   static
00117   int
00118   descend_compare(const void* A_orig, const void* B_orig)
00119     {
00120     const eT& A = *(static_cast<const eT*>(A_orig));
00121     const eT& B = *(static_cast<const eT*>(B_orig));
00122     
00123     const T abs_A = std::abs(A);
00124     const T abs_B = std::abs(B);
00125     
00126     if(abs_A < abs_B)
00127       {
00128       return +1;
00129       }
00130     else
00131     if(abs_A > abs_B)
00132       {
00133       return -1;
00134       }
00135     else
00136       {
00137       return 0;
00138       }
00139     }
00140   
00141   
00142   };
00143 
00144 
00145 
00146 template<typename eT>
00147 inline 
00148 void
00149 op_sort::direct_sort(eT* X, const u32 n_elem, const u32 sort_type)
00150   {
00151   arma_extra_debug_sigprint();
00152   
00153   if(sort_type == 0)
00154     {
00155     std::qsort(X, n_elem, sizeof(eT), arma_qsort_helper<eT>::ascend_compare);
00156     }
00157   else
00158     {
00159     std::qsort(X, n_elem, sizeof(eT), arma_qsort_helper<eT>::descend_compare);
00160     }
00161   }
00162 
00163 
00164 
00165 template<typename T1>
00166 inline
00167 void
00168 op_sort::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_sort>& in)
00169   {
00170   arma_extra_debug_sigprint();
00171   
00172   typedef typename T1::elem_type eT;
00173   
00174   unwrap<T1> tmp(in.m);
00175   const Mat<eT>& X = tmp.M;
00176   
00177   const u32 sort_type = in.aux_u32_a;
00178   const u32 dim       = in.aux_u32_b;
00179   
00180   arma_debug_check( (sort_type > 1), "op_sort::apply(): incorrect usage. sort_type must be 0 or 1");
00181   arma_debug_check( (dim > 1),       "op_sort::apply(): incorrect usage. dim must be 0 or 1"      );
00182   
00183   
00184   if(dim == 0)  // column-wise
00185     {
00186     arma_extra_debug_print("op_sort::apply(), dim = 0");
00187     
00188     out = X;
00189     
00190     for(u32 col=0; col<out.n_cols; ++col)
00191       {
00192       op_sort::direct_sort( out.colptr(col), out.n_rows, sort_type );
00193       }
00194     }
00195   else
00196   if(dim == 1)  // row-wise
00197     {
00198     if(X.n_rows != 1)  // not a row vector
00199       {
00200       arma_extra_debug_print("op_sort::apply(), dim = 1, generic");
00201       
00202       out.set_size(X.n_rows, X.n_cols);
00203       podarray<eT> tmp_array(X.n_cols);
00204       
00205       for(u32 row=0; row<out.n_rows; ++row)
00206         {
00207         
00208         for(u32 col=0; col<out.n_cols; ++col)
00209           {
00210           tmp_array[col] = X.at(row,col);
00211           }
00212         
00213         op_sort::direct_sort( tmp_array.memptr(), out.n_cols, sort_type );
00214         
00215         for(u32 col=0; col<out.n_cols; ++col)
00216           {
00217           out.at(row,col) = tmp_array[col];
00218           }
00219         
00220         }
00221       }
00222     else  // a row vector
00223       {
00224       arma_extra_debug_print("op_sort::apply(), dim = 1, vector specific");
00225       
00226       out = X;
00227       op_sort::direct_sort(out.memptr(), out.n_elem, sort_type);
00228       }
00229     }
00230   
00231   }
00232 
00233 
00234 //! @}