op_dot_meat.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2009 NICTA
00002 // 
00003 // Authors:
00004 // - Conrad Sanderson (conradsand at ieee dot org)
00005 // 
00006 // This file is part of the Armadillo C++ library.
00007 // It is provided without any warranty of fitness
00008 // for any purpose. You can redistribute this file
00009 // and/or modify it under the terms of the GNU
00010 // Lesser General Public License (LGPL) as published
00011 // by the Free Software Foundation, either version 3
00012 // of the License or (at your option) any later version.
00013 // (see http://www.opensource.org/licenses for more info)
00014 
00015 
00016 //! \addtogroup op_dot
00017 //! @{
00018 
00019 
00020 
00021 //! for two arrays
00022 template<typename eT>
00023 inline
00024 arma_pure
00025 eT
00026 op_dot::direct_dot(const u32 n_elem, const eT* const A, const eT* const B)
00027   {
00028   arma_extra_debug_sigprint();
00029   
00030   eT val1 = eT(0);
00031   eT val2 = eT(0);
00032   
00033   u32 i,j;
00034   for(i=0, j=1; j<n_elem; i+=2, j+=2)
00035     {
00036     val1 += A[i] * B[i];
00037     val2 += A[j] * B[j];
00038     }
00039   
00040   if(i < n_elem)
00041     {
00042     val1 += A[i] * B[i];
00043     }
00044   
00045   return val1+val2;
00046   }
00047 
00048 
00049 
00050 //! for three arrays
00051 template<typename eT>
00052 inline
00053 arma_pure
00054 eT
00055 op_dot::direct_dot(const u32 n_elem, const eT* const A, const eT* const B, const eT* C)
00056   {
00057   arma_extra_debug_sigprint();
00058   
00059   eT val = eT(0);
00060   
00061   for(u32 i=0; i<n_elem; ++i)
00062     {
00063     val += A[i] * B[i] * C[i];
00064     }
00065 
00066   return val;
00067   }
00068 
00069 
00070 
00071 template<typename T1, typename T2>
00072 inline
00073 typename T1::elem_type
00074 op_dot::apply(const Base<typename T1::elem_type,T1>& A_orig, const Base<typename T1::elem_type,T2>& B_orig)
00075   {
00076   arma_extra_debug_sigprint();
00077   
00078   typedef typename T1::elem_type eT;
00079   
00080   const unwrap_to_elem_access<T1> A(A_orig.get_ref());
00081   const unwrap_to_elem_access<T2> B(B_orig.get_ref());
00082 
00083   arma_debug_check( (A.M.n_elem != B.M.n_elem), "dot(): objects must have the same number of elements" );
00084   
00085   const u32 n_elem = A.M.n_elem;
00086   eT val = eT(0);
00087   
00088   for(u32 i=0; i<n_elem; ++i)
00089     {
00090     val += A[i] * B[i];
00091     }
00092   
00093   return val;
00094   }
00095 
00096 
00097 
00098 template<typename T1, typename T2>
00099 inline
00100 typename T1::elem_type
00101 op_norm_dot::apply(const Base<typename T1::elem_type,T1>& A_orig, const Base<typename T1::elem_type,T2>& B_orig)
00102   {
00103   arma_extra_debug_sigprint();
00104   
00105   typedef typename T1::elem_type eT;
00106   
00107   const unwrap_to_elem_access<T1> A(A_orig.get_ref());
00108   const unwrap_to_elem_access<T2> B(B_orig.get_ref());
00109 
00110   arma_debug_check( (A.M.n_elem != B.M.n_elem), "norm_dot(): objects must have the same number of elements" );
00111   
00112   const u32 n_elem = A.M.n_elem;
00113   
00114   eT acc1 = eT(0);
00115   eT acc2 = eT(0);
00116   eT acc3 = eT(0);
00117   
00118   for(u32 i=0; i<n_elem; ++i)
00119     {
00120     const eT tmpA = A[i];
00121     const eT tmpB = B[i];
00122     
00123     acc1 += tmpA * tmpA;
00124     acc2 += tmpB * tmpB;
00125     acc3 += tmpA * tmpB;
00126     }
00127     
00128   return acc3 / ( std::sqrt(acc1 * acc2) );   // TODO: this only makes sense for eT = float, double or complex
00129   }
00130 
00131 
00132 //! @}