MLPACK  1.0.10
svd_batch_learning.hpp
Go to the documentation of this file.
1 
20 #ifndef __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
21 #define __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
22 
23 #include <mlpack/core.hpp>
24 
25 namespace mlpack
26 {
27 namespace amf
28 {
30 {
31  public:
32  SVDBatchLearning(double u = 0.0002,
33  double kw = 0,
34  double kh = 0,
35  double momentum = 0.9,
36  double min = -DBL_MIN,
37  double max = DBL_MAX)
38  : u(u), kw(kw), kh(kh), min(min), max(max), momentum(momentum)
39  {}
40 
41  template<typename MatType>
42  void Initialize(const MatType& dataset, const size_t rank)
43  {
44  const size_t n = dataset.n_rows;
45  const size_t m = dataset.n_cols;
46 
47  mW.zeros(n, rank);
48  mH.zeros(rank, m);
49  }
50 
60  template<typename MatType>
61  inline void WUpdate(const MatType& V,
62  arma::mat& W,
63  const arma::mat& H)
64  {
65  size_t n = V.n_rows;
66  size_t m = V.n_cols;
67 
68  size_t r = W.n_cols;
69 
70  mW = momentum * mW;
71 
72  arma::mat deltaW(n, r);
73  deltaW.zeros();
74 
75  for(size_t i = 0;i < n;i++)
76  {
77  for(size_t j = 0;j < m;j++)
78  {
79  double val;
80  if((val = V(i, j)) != 0)
81  deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
82  arma::trans(H.col(j));
83  }
84  if(kw != 0) deltaW.row(i) -= kw * W.row(i);
85  }
86 
87  mW += u * deltaW;
88  W += mW;
89  }
90 
100  template<typename MatType>
101  inline void HUpdate(const MatType& V,
102  const arma::mat& W,
103  arma::mat& H)
104  {
105  size_t n = V.n_rows;
106  size_t m = V.n_cols;
107 
108  size_t r = W.n_cols;
109 
110  mH = momentum * mH;
111 
112  arma::mat deltaH(r, m);
113  deltaH.zeros();
114 
115  for(size_t j = 0;j < m;j++)
116  {
117  for(size_t i = 0;i < n;i++)
118  {
119  double val;
120  if((val = V(i, j)) != 0)
121  deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) *
122  arma::trans(W.row(i));
123  }
124  if(kh != 0) deltaH.col(j) -= kh * H.col(j);
125  }
126 
127  mH += u*deltaH;
128  H += mH;
129  }
130 
131  private:
132  double u;
133  double kw;
134  double kh;
135  double min;
136  double max;
137  double momentum;
138 
139  arma::mat mW;
140  arma::mat mH;
141 };
142 
143 template<>
144 inline void SVDBatchLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
145  arma::mat& W,
146  const arma::mat& H)
147 {
148  size_t n = V.n_rows;
149 
150  size_t r = W.n_cols;
151 
152  mW = momentum * mW;
153 
154  arma::mat deltaW(n, r);
155  deltaW.zeros();
156 
157  for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
158  {
159  size_t row = it.row();
160  size_t col = it.col();
161  deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
162  arma::trans(H.col(col));
163  }
164 
165  if(kw != 0) for(size_t i = 0; i < n; i++)
166  {
167  deltaW.row(i) -= kw * W.row(i);
168  }
169 
170  mW += u * deltaW;
171  W += mW;
172 }
173 
174 template<>
175 inline void SVDBatchLearning::HUpdate<arma::sp_mat>(const arma::sp_mat& V,
176  const arma::mat& W,
177  arma::mat& H)
178 {
179  size_t m = V.n_cols;
180 
181  size_t r = W.n_cols;
182 
183  mH = momentum * mH;
184 
185  arma::mat deltaH(r, m);
186  deltaH.zeros();
187 
188  for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
189  {
190  size_t row = it.row();
191  size_t col = it.col();
192  deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
193  arma::trans(W.row(row));
194  }
195 
196  if(kh != 0) for(size_t j = 0; j < m; j++)
197  {
198  deltaH.col(j) -= kh * H.col(j);
199  }
200 
201  mH += u*deltaH;
202  H += mH;
203 }
204 
205 } // namespace amf
206 } // namespace mlpack
207 
208 
209 #endif
210 
211