00001 00031 #include <itpp/stat/mog_diag_em.h> 00032 #include <itpp/base/math/log_exp.h> 00033 #include <itpp/base/timing.h> 00034 00035 #include <iostream> 00036 #include <iomanip> 00037 00038 namespace itpp { 00039 00041 void inline MOG_diag_EM_sup::update_internals() { 00042 00043 double Ddiv2_log_2pi = D/2.0 * std::log(m_2pi); 00044 00045 for(int k=0;k<K;k++) c_log_weights[k] = std::log(c_weights[k]); 00046 00047 for(int k=0;k<K;k++) { 00048 double acc = 0.0; 00049 double * c_diag_cov = c_diag_covs[k]; 00050 double * c_diag_cov_inv_etc = c_diag_covs_inv_etc[k]; 00051 00052 for(int d=0;d<D;d++) { 00053 double tmp = c_diag_cov[d]; 00054 c_diag_cov_inv_etc[d] = 1.0/(2.0*tmp); 00055 acc += std::log(tmp); 00056 } 00057 00058 c_log_det_etc[k] = -Ddiv2_log_2pi - 0.5*acc; 00059 } 00060 00061 } 00062 00063 00065 void inline MOG_diag_EM_sup::sanitise_params() { 00066 00067 double acc = 0.0; 00068 for(int k=0;k<K;k++) { 00069 if(c_weights[k] < weight_floor) c_weights[k] = weight_floor; 00070 if(c_weights[k] > 1.0) c_weights[k] = 1.0; 00071 acc += c_weights[k]; 00072 } 00073 for(int k=0;k<K;k++) c_weights[k] /= acc; 00074 00075 for(int k=0;k<K;k++) 00076 for(int d=0;d<D;d++) 00077 if(c_diag_covs[k][d] < var_floor) c_diag_covs[k][d] = var_floor; 00078 00079 } 00080 00082 double MOG_diag_EM_sup::ml_update_params() { 00083 00084 double acc_loglhood = 0.0; 00085 00086 for(int k=0;k<K;k++) { 00087 c_acc_loglhood_K[k] = 0.0; 00088 00089 double * c_acc_mean = c_acc_means[k]; 00090 double * c_acc_cov = c_acc_covs[k]; 00091 00092 for(int d=0;d<D;d++) { c_acc_mean[d] = 0.0; c_acc_cov[d] = 0.0; } 00093 } 00094 00095 for(int n=0;n<N;n++) { 00096 double * c_x = c_X[n]; 00097 00098 bool danger = paranoid; 00099 for(int k=0;k<K;k++) { 00100 double tmp = c_log_weights[k] + MOG_diag::log_lhood_single_gaus_internal(c_x, k); 00101 c_tmpvecK[k] = tmp; 00102 if(tmp >= log_max_K) danger = true; 00103 } 00104 00105 if(danger) { 00106 00107 double log_sum = c_tmpvecK[0]; for(int k=1;k<K;k++) log_sum = log_add( log_sum, c_tmpvecK[k] ); 00108 acc_loglhood += log_sum; 00109 00110 for(int k=0;k<K;k++) { 00111 00112 double * c_acc_mean = c_acc_means[k]; 00113 double * c_acc_cov = c_acc_covs[k]; 00114 00115 double tmp_k = trunc_exp(c_tmpvecK[k] - log_sum); 00116 acc_loglhood_K[k] += tmp_k; 00117 00118 for(int d=0;d<D;d++) { 00119 double tmp_x = c_x[d]; 00120 c_acc_mean[d] += tmp_k * tmp_x; 00121 c_acc_cov[d] += tmp_k * tmp_x*tmp_x; 00122 } 00123 } 00124 } 00125 else { 00126 00127 double sum = 0.0; for(int k=0;k<K;k++) { double tmp = std::exp(c_tmpvecK[k]); c_tmpvecK[k] = tmp; sum += tmp; } 00128 acc_loglhood += std::log(sum); 00129 00130 for(int k=0;k<K;k++) { 00131 00132 double * c_acc_mean = c_acc_means[k]; 00133 double * c_acc_cov = c_acc_covs[k]; 00134 00135 double tmp_k = c_tmpvecK[k] / sum; 00136 c_acc_loglhood_K[k] += tmp_k; 00137 00138 for(int d=0;d<D;d++) { 00139 double tmp_x = c_x[d]; 00140 c_acc_mean[d] += tmp_k * tmp_x; 00141 c_acc_cov[d] += tmp_k * tmp_x*tmp_x; 00142 } 00143 } 00144 } 00145 } 00146 00147 for(int k=0;k<K;k++) { 00148 00149 double * c_mean = c_means[k]; 00150 double * c_diag_cov = c_diag_covs[k]; 00151 00152 double * c_acc_mean = c_acc_means[k]; 00153 double * c_acc_cov = c_acc_covs[k]; 00154 00155 double tmp_k = c_acc_loglhood_K[k]; 00156 00157 c_weights[k] = tmp_k / N; 00158 00159 for(int d=0;d<D;d++) { 00160 double tmp_mean = c_acc_mean[d] / tmp_k; 00161 c_mean[d] = tmp_mean; 00162 c_diag_cov[d] = c_acc_cov[d] / tmp_k - tmp_mean*tmp_mean; 00163 } 00164 } 00165 00166 return(acc_loglhood/N); 00167 00168 } 00169 00170 00171 void MOG_diag_EM_sup::ml_iterate() { 00172 using std::cout; 00173 using std::endl; 00174 using std::setw; 00175 using std::showpos; 00176 using std::noshowpos; 00177 using std::scientific; 00178 using std::fixed; 00179 using std::flush; 00180 using std::setprecision; 00181 00182 double avg_log_lhood_old = -1.0*std::numeric_limits<double>::max(); 00183 00184 Real_Timer tt; 00185 00186 if(verbose) { 00187 cout << "MOG_diag_EM_sup::ml_iterate()" << endl; 00188 cout << setw(14) << "iteration"; 00189 cout << setw(14) << "avg_loglhood"; 00190 cout << setw(14) << "delta"; 00191 cout << setw(10) << "toc"; 00192 cout << endl; 00193 } 00194 00195 for(int i=0; i<max_iter; i++) { 00196 sanitise_params(); 00197 update_internals(); 00198 00199 if(verbose) tt.tic(); 00200 double avg_log_lhood_new = ml_update_params(); 00201 00202 if(verbose) { 00203 double delta = avg_log_lhood_new - avg_log_lhood_old; 00204 00205 cout << noshowpos << fixed; 00206 cout << setw(14) << i; 00207 cout << showpos << scientific << setprecision(3); 00208 cout << setw(14) << avg_log_lhood_new; 00209 cout << setw(14) << delta; 00210 cout << noshowpos << fixed; 00211 cout << setw(10) << tt.toc(); 00212 cout << endl << flush; 00213 } 00214 00215 if(avg_log_lhood_new <= avg_log_lhood_old) break; 00216 00217 avg_log_lhood_old = avg_log_lhood_new; 00218 } 00219 } 00220 00221 00222 void MOG_diag_EM_sup::ml(MOG_diag &model_in, Array<vec> &X_in, int max_iter_in, double var_floor_in, double weight_floor_in, bool verbose_in) { 00223 00224 it_assert(model_in.is_valid(), "MOG_diag_EM_sup::ml(): initial model not valid" ); 00225 it_assert(check_array_uniformity(X_in), "MOG_diag_EM_sup::ml(): 'X' is empty or contains vectors of varying dimensionality" ); 00226 it_assert( (max_iter_in > 0), "MOG_diag_EM_sup::ml(): 'max_iter' needs to be greater than zero" ); 00227 00228 verbose = verbose_in; 00229 00230 N = X_in.size(); 00231 00232 Array<vec> means_in = model_in.get_means(); 00233 Array<vec> diag_covs_in = model_in.get_diag_covs(); 00234 vec weights_in = model_in.get_weights(); 00235 00236 init(means_in, diag_covs_in, weights_in); 00237 00238 means_in.set_size(0); diag_covs_in.set_size(0); weights_in.set_size(0); 00239 00240 if(K > N) it_warning("MOG_diag_EM_sup::ml(): WARNING: K > N"); 00241 else 00242 if(K > N/10) it_warning("MOG_diag_EM_sup::ml(): WARNING: K > N/10"); 00243 00244 var_floor = var_floor_in; 00245 weight_floor = weight_floor_in; 00246 00247 const double tiny = std::numeric_limits<double>::min(); 00248 if(var_floor < tiny) var_floor = tiny; 00249 if(weight_floor < tiny) weight_floor = tiny; 00250 if(weight_floor > 1.0/K ) weight_floor = 1.0/K; 00251 00252 max_iter = max_iter_in; 00253 00254 tmpvecK.set_size(K); 00255 tmpvecD.set_size(D); 00256 acc_loglhood_K.set_size(K); 00257 00258 acc_means.set_size(K); for(int k=0;k<K;k++) acc_means(k).set_size(D); 00259 acc_covs.set_size(K); for(int k=0;k<K;k++) acc_covs(k).set_size(D); 00260 00261 c_X = enable_c_access(X_in); 00262 c_tmpvecK = enable_c_access(tmpvecK); 00263 c_tmpvecD = enable_c_access(tmpvecD); 00264 c_acc_loglhood_K = enable_c_access(acc_loglhood_K); 00265 c_acc_means = enable_c_access(acc_means); 00266 c_acc_covs = enable_c_access(acc_covs); 00267 00268 ml_iterate(); 00269 00270 model_in.init(means, diag_covs, weights); 00271 00272 disable_c_access(c_X); 00273 disable_c_access(c_tmpvecK); 00274 disable_c_access(c_tmpvecD); 00275 disable_c_access(c_acc_loglhood_K); 00276 disable_c_access(c_acc_means); 00277 disable_c_access(c_acc_covs); 00278 00279 00280 tmpvecK.set_size(0); 00281 tmpvecD.set_size(0); 00282 acc_loglhood_K.set_size(0); 00283 acc_means.set_size(0); 00284 acc_covs.set_size(0); 00285 00286 cleanup(); 00287 00288 } 00289 00290 void MOG_diag_EM_sup::map(MOG_diag &model_in, MOG_diag &prior_model_in, Array<vec> &X_in, int max_iter_in, double alpha_in, double var_floor_in, double weight_floor_in, bool verbose_in) { 00291 it_assert(false, "MOG_diag_EM_sup::map(): not implemented yet"); 00292 } 00293 00294 00295 // 00296 // convenience functions 00297 00298 void MOG_diag_ML(MOG_diag &model_in, Array<vec> &X_in, int max_iter_in, double var_floor_in, double weight_floor_in, bool verbose_in) { 00299 MOG_diag_EM_sup EM; 00300 EM.ml(model_in, X_in, max_iter_in, var_floor_in, weight_floor_in, verbose_in); 00301 } 00302 00303 void MOG_diag_MAP(MOG_diag &model_in, MOG_diag &prior_model_in, Array<vec> &X_in, int max_iter_in, double alpha_in, double var_floor_in, double weight_floor_in, bool verbose_in) { 00304 it_assert(false, "MOG_diag_MAP(): not implemented yet"); 00305 } 00306 00307 } 00308
Generated on Sat Apr 19 10:41:57 2008 for IT++ by Doxygen 1.5.5