mult_div_update_rules.hpp

Go to the documentation of this file.
00001 
00028 #ifndef __MLPACK_METHODS_NMF_MULT_DIV_UPDATE_RULES_HPP
00029 #define __MLPACK_METHODS_NMF_MULT_DIV_UPDATE_RULES_HPP
00030 
00031 #include <mlpack/core.hpp>
00032 
00033 namespace mlpack {
00034 namespace nmf {
00035 
00043 class WMultiplicativeDivergenceRule
00044 {
00045  public:
00046   // Empty constructor required for the WUpdateRule template.
00047   WMultiplicativeDivergenceRule() { }
00048 
00057   template<typename MatType>
00058   inline static void Update(const MatType& V,
00059                             arma::mat& W,
00060                             const arma::mat& H)
00061   {
00062     // Simple implementation left in the header file.
00063     arma::mat t1;
00064     arma::rowvec t2;
00065 
00066     t1 = W * H;
00067     for (size_t i = 0; i < W.n_rows; ++i)
00068     {
00069       for (size_t j = 0; j < W.n_cols; ++j)
00070       {
00071         // Writing this as a single expression does not work as of Armadillo
00072         // 3.920.  This should be fixed in a future release, and then the code
00073         // below can be fixed.
00074         //t2 = H.row(j) % V.row(i) / t1.row(i);
00075         t2.set_size(H.n_cols);
00076         for (size_t k = 0; k < t2.n_elem; ++k)
00077         {
00078           t2(k) = H(j, k) * V(i, k) / t1(i, k);
00079         }
00080 
00081         W(i, j) = W(i, j) * sum(t2) / sum(H.row(j));
00082       }
00083     }
00084   }
00085 };
00086 
00094 class HMultiplicativeDivergenceRule
00095 {
00096  public:
00097   // Empty constructor required for the HUpdateRule template.
00098   HMultiplicativeDivergenceRule() { }
00099 
00108   template<typename MatType>
00109   inline static void Update(const MatType& V,
00110                             const arma::mat& W,
00111                             arma::mat& H)
00112   {
00113     // Simple implementation left in the header file.
00114     arma::mat t1;
00115     arma::colvec t2;
00116 
00117     t1 = W * H;
00118     for (size_t i = 0; i < H.n_rows; i++)
00119     {
00120       for (size_t j = 0; j < H.n_cols; j++)
00121       {
00122         // Writing this as a single expression does not work as of Armadillo
00123         // 3.920.  This should be fixed in a future release, and then the code
00124         // below can be fixed.
00125         //t2 = W.col(i) % V.col(j) / t1.col(j);
00126         t2.set_size(W.n_rows);
00127         for (size_t k = 0; k < t2.n_elem; ++k)
00128         {
00129           t2(k) = W(k, i) * V(k, j) / t1(k, j);
00130         }
00131 
00132         H(i,j) = H(i,j) * sum(t2) / sum(W.col(i));
00133       }
00134     }
00135   }
00136 };
00137 
00138 }; // namespace nmf
00139 }; // namespace mlpack
00140 
00141 #endif

Generated on 29 Sep 2016 for MLPACK by doxygen 1.6.1