14 #ifndef __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP 15 #define __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP 70 template<
typename MatType>
71 void Initialize(
const MatType& dataset,
const size_t rank)
73 const size_t n = dataset.n_rows;
74 const size_t m = dataset.n_cols;
89 template<
typename MatType>
105 for (
size_t i = 0; i < n; i++)
107 for (
size_t j = 0; j < m; j++)
109 const double val = V(i, j);
111 deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
112 arma::trans(H.col(j));
116 deltaW.row(i) -=
kw * W.row(i);
134 template<
typename MatType>
150 for (
size_t j = 0; j < m; j++)
152 for (
size_t i = 0; i < n; i++)
154 const double val = V(i, j);
156 deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) * W.row(i).t();
160 deltaH.col(j) -=
kh * H.col(j);
170 template<
typename Archive>
205 inline void SVDBatchLearning::WUpdate<arma::sp_mat>(
const arma::sp_mat& V,
209 const size_t n = V.n_rows;
210 const size_t r = W.n_cols;
217 for (arma::sp_mat::const_iterator it = V.begin(); it != V.end(); ++it)
219 const size_t row = it.row();
220 const size_t col = it.col();
221 deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
222 arma::trans(H.col(col));
233 inline void SVDBatchLearning::HUpdate<arma::sp_mat>(
const arma::sp_mat& V,
237 const size_t m = V.n_cols;
238 const size_t r = W.n_cols;
245 for (arma::sp_mat::const_iterator it = V.begin(); it != V.end(); ++it)
247 const size_t row = it.row();
248 const size_t col = it.col();
249 deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
263 #endif // __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
double kh
Regularization parameter for matrix H.
void Serialize(Archive &ar, const unsigned int)
Serialize the SVDBatchLearning object.
double u
Step size of the algorithm.
Linear algebra utility functions, generally performed on matrices or vectors.
FirstShim< T > CreateNVP(T &t, const std::string &name, typename boost::enable_if< HasSerialize< T >>::type *=0)
Call this function to produce a name-value pair; this is similar to BOOST_SERIALIZATION_NVP(), but should be used for types that have a Serialize() function (or contain a type that has a Serialize() function) instead of a serialize() function.
SVDBatchLearning(double u=0.0002, double kw=0, double kh=0, double momentum=0.9)
SVD batch learning constructor.
void Initialize(const MatType &dataset, const size_t rank)
Initialize parameters before factorization.
double kw
Regularization parameter for matrix W.
This class implements SVD batch learning with momentum.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
Include all of the base components required to write MLPACK methods, and the main MLPACK Doxygen documentation.
arma::mat mW
Momentum matrix for matrix W.
arma::mat mH
Momentum matrix for matrix H.
double momentum
Momentum value (between 0 and 1).