blas_proto.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2010 NICTA and the authors listed below
00002 // http://nicta.com.au
00003 // 
00004 // Authors:
00005 // - Conrad Sanderson (conradsand at ieee dot org)
00006 // 
00007 // This file is part of the Armadillo C++ library.
00008 // It is provided without any warranty of fitness
00009 // for any purpose. You can redistribute this file
00010 // and/or modify it under the terms of the GNU
00011 // Lesser General Public License (LGPL) as published
00012 // by the Free Software Foundation, either version 3
00013 // of the License or (at your option) any later version.
00014 // (see http://www.opensource.org/licenses for more info)
00015 
00016 
00017 #ifdef ARMA_USE_BLAS
00018 
00019 //! \namespace blas namespace for BLAS functions
00020 namespace blas
00021   {
00022   extern "C"
00023     {
00024     void sgemv_(const char* transA, const int* m, const int* n, const float*  alpha, const float*  A, const int* ldA, const float*  x, const int* incx, const float*  beta, float*  y, const int* incy);
00025     void dgemv_(const char* transA, const int* m, const int* n, const double* alpha, const double* A, const int* ldA, const double* x, const int* incx, const double* beta, double* y, const int* incy);
00026     void cgemv_(const char* transA, const int* m, const int* n, const void*   alpha, const void*   A, const int* ldA, const void*   x, const int* incx, const void*   beta, void*   y, const int* incy);
00027     void zgemv_(const char* transA, const int* m, const int* n, const void*   alpha, const void*   A, const int* ldA, const void*   x, const int* incx, const void*   beta, void*   y, const int* incy);
00028     
00029     void sgemm_(const char* transA, const char* transB, const int* m, const int* n, const int* k, const float*  alpha, const float*  A, const int* ldA, const float*  B, const int* ldB, const float*  beta, float*  C, const int* ldC);
00030     void dgemm_(const char* transA, const char* transB, const int* m, const int* n, const int* k, const double* alpha, const double* A, const int* ldA, const double* B, const int* ldB, const double* beta, double* C, const int* ldC);
00031     void cgemm_(const char* transA, const char* transB, const int* m, const int* n, const int* k, const void*   alpha, const void*   A, const int* ldA, const void*   B, const int* ldB, const void*   beta, void*   C, const int* ldC);
00032     void zgemm_(const char* transA, const char* transB, const int* m, const int* n, const int* k, const void*   alpha, const void*   A, const int* ldA, const void*   B, const int* ldB, const void*   beta, void*   C, const int* ldC);
00033 
00034 //     float  sdot_(const int* n, const float*  x, const int* incx, const float*  y, const int* incy);
00035 //     double ddot_(const int* n, const double* x, const int* incx, const double* y, const int* incy);
00036 
00037 //     void   dswap_(const int* n, double* x, const int* incx, double* y, const int* incy);
00038 //     void   dscal_(const int* n, const double* alpha, double* x, const int* incx);
00039 //     void   dcopy_(const int* n, const double* x, const int* incx, double* y, const int* incy);
00040 //     void   daxpy_(const int* n, const double* alpha, const double* x, const int* incx, double* y, const int* incy);
00041 //     void    dger_(const int* m, const int* n, const double* alpha, const double* x, const int* incx, const double* y, const int* incy, double* A, const int* ldA);
00042     }
00043   
00044   
00045   
00046 //   template<typename eT>
00047 //   inline
00048 //   eT
00049 //   dot_(const int* n, const eT* x, const int* incx, const eT* y, const int* incy)
00050 //     {
00051 //     arma_type_check<is_supported_blas_type<eT>::value == false>::apply();
00052 //     
00053 //     if(is_float<eT>::value == true)
00054 //       {
00055 //       typedef float T;
00056 //       return sdot_(n, (const T*)x, incx, (const T*)y, incy);
00057 //       }
00058 //     else
00059 //     if(is_double<eT>::value == true)
00060 //       {
00061 //       typedef double T;
00062 //       return ddot_(n, (const T*)x, incx, (const T*)y, incy);
00063 //       }
00064 //     
00065 //     return eT();  // prevent compiler warnings
00066 //     }
00067   
00068   
00069   
00070   template<typename eT>
00071   inline
00072   void
00073   gemv_(const char* transA, const int* m, const int* n, const eT* alpha, const eT* A, const int* ldA, const eT* x, const int* incx, const eT* beta, eT* y, const int* incy)
00074     {
00075     arma_type_check<is_supported_blas_type<eT>::value == false>::apply();
00076     
00077     if(is_float<eT>::value == true)
00078       {
00079       typedef float T;
00080       sgemv_(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
00081       }
00082     else
00083     if(is_double<eT>::value == true)
00084       {
00085       typedef double T;
00086       dgemv_(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
00087       }
00088     else
00089     if(is_supported_complex_float<eT>::value == true)
00090       {
00091       typedef std::complex<float> T;
00092       cgemv_(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
00093       }
00094     else
00095     if(is_supported_complex_double<eT>::value == true)
00096       {
00097       typedef std::complex<double> T;
00098       zgemv_(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
00099       }
00100     
00101     }
00102   
00103   
00104   
00105   template<typename eT>
00106   inline
00107   void
00108   gemm_(const char* transA, const char* transB, const int* m, const int* n, const int* k, const eT* alpha, const eT* A, const int* ldA, const eT* B, const int* ldB, const eT* beta, eT* C, const int* ldC)
00109     {
00110     arma_type_check<is_supported_blas_type<eT>::value == false>::apply();
00111     
00112     if(is_float<eT>::value == true)
00113       {
00114       typedef float T;
00115       sgemm_(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
00116       }
00117     else
00118     if(is_double<eT>::value == true)
00119       {
00120       typedef double T;
00121       dgemm_(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
00122       }
00123     else
00124     if(is_supported_complex_float<eT>::value == true)
00125       {
00126       typedef std::complex<float> T;
00127       cgemm_(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
00128       }
00129     else
00130     if(is_supported_complex_double<eT>::value == true)
00131       {
00132       typedef std::complex<double> T;
00133       zgemm_(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
00134       }
00135     
00136     }
00137   
00138   }
00139 
00140 #endif