00001 00030 #ifndef MOG_GENERIC_H 00031 #define MOG_GENERIC_H 00032 00033 #include <itpp/base/vec.h> 00034 #include <itpp/base/mat.h> 00035 #include <itpp/base/array.h> 00036 00037 00038 namespace itpp { 00039 00056 class MOG_generic { 00057 00058 public: 00059 00065 MOG_generic() { init(); } 00066 00070 MOG_generic(const std::string &name_in) { load(name_in); } 00071 00077 MOG_generic(const int &K_in, const int &D_in, bool full_in=false) { init(K_in, D_in, full_in); } 00078 00086 MOG_generic(Array<vec> &means_in, bool full_in=false) { init(means_in, full_in); } 00087 00094 MOG_generic(Array<vec> &means_in, Array<vec> &diag_covs_in, vec &weights_in) { init(means_in, diag_covs_in, weights_in); } 00095 00102 MOG_generic(Array<vec> &means_in, Array<mat> &full_covs_in, vec &weights_in) { init(means_in, full_covs_in, weights_in); } 00103 00105 virtual ~MOG_generic() { cleanup(); } 00106 00111 void init(); 00112 00118 void init(const int &K_in, const int &D_in, bool full_in=false); 00119 00127 void init(Array<vec> &means_in, bool full_in=false); 00128 00135 void init(Array<vec> &means_in, Array<vec> &diag_covs_in, vec &weights_in); 00136 00143 void init(Array<vec> &means_in, Array<mat> &full_covs_in, vec &weights_in); 00144 00149 virtual void cleanup(); 00150 00152 bool is_valid() const { return valid; } 00153 00155 bool is_full() const { return full; } 00156 00158 int get_K() const { if(valid) return(K); else return(0); } 00159 00161 int get_D() const { if(valid) return(D); else return(0); } 00162 00164 vec get_weights() const { vec tmp; if(valid) { tmp = weights; } return tmp; } 00165 00167 Array<vec> get_means() const { Array<vec> tmp; if(valid) { tmp = means; } return tmp; } 00168 00170 Array<vec> get_diag_covs() const { Array<vec> tmp; if(valid && !full) { tmp = diag_covs; } return tmp; } 00171 00173 Array<mat> get_full_covs() const { Array<mat> tmp; if(valid && full) { tmp = full_covs; } return tmp; } 00174 00178 void set_means(Array<vec> &means_in); 00179 00183 
void set_diag_covs(Array<vec> &diag_covs_in); 00184 00188 void set_full_covs(Array<mat> &full_covs_in); 00189 00193 void set_weights(vec &weights_in); 00194 00196 void set_means_zero(); 00197 00199 void set_diag_covs_unity(); 00200 00202 void set_full_covs_unity(); 00203 00205 void set_weights_uniform(); 00206 00212 void set_checks(bool do_checks_in) { do_checks = do_checks_in; } 00213 00217 void set_paranoid(bool paranoid_in) { paranoid = paranoid_in; } 00218 00222 virtual void load(const std::string &name_in); 00223 00227 virtual void save(const std::string &name_in) const; 00228 00245 virtual void join(const MOG_generic &B_in); 00246 00254 virtual void convert_to_diag(); 00255 00261 virtual void convert_to_full(); 00262 00264 virtual double log_lhood_single_gaus(const vec &x_in, const int k); 00265 00267 virtual double log_lhood(const vec &x_in); 00268 00270 virtual double lhood(const vec &x_in); 00271 00273 virtual double avg_log_lhood(const Array<vec> &X_in); 00274 00275 protected: 00276 00278 bool do_checks; 00279 00281 bool valid; 00282 00284 bool full; 00285 00287 bool paranoid; 00288 00290 int K; 00291 00293 int D; 00294 00296 Array<vec> means; 00297 00299 Array<vec> diag_covs; 00300 00302 Array<mat> full_covs; 00303 00305 vec weights; 00306 00308 double log_max_K; 00309 00315 vec log_det_etc; 00316 00318 vec log_weights; 00319 00321 Array<mat> full_covs_inv; 00322 00324 Array<vec> diag_covs_inv_etc; 00325 00327 bool check_size(const vec &x_in) const; 00328 00330 bool check_size(const Array<vec> &X_in) const; 00331 00333 bool check_array_uniformity(const Array<vec> & A) const; 00334 00336 void set_means_internal(Array<vec> &means_in); 00338 void set_diag_covs_internal(Array<vec> &diag_covs_in); 00340 void set_full_covs_internal(Array<mat> &full_covs_in); 00342 void set_weights_internal(vec &_weigths); 00343 00345 void set_means_zero_internal(); 00347 void set_diag_covs_unity_internal(); 00349 void set_full_covs_unity_internal(); 00351 void 
set_weights_uniform_internal(); 00352 00354 void convert_to_diag_internal(); 00356 void convert_to_full_internal(); 00357 00359 virtual void setup_means(); 00360 00362 virtual void setup_covs(); 00363 00365 virtual void setup_weights(); 00366 00368 virtual void setup_misc(); 00369 00371 virtual double log_lhood_single_gaus_internal(const vec &x_in, const int k); 00373 virtual double log_lhood_internal(const vec &x_in); 00375 virtual double lhood_internal(const vec &x_in); 00376 00377 private: 00378 vec tmpvecD; 00379 vec tmpvecK; 00380 00381 }; 00382 00383 } // namespace itpp 00384 00385 #endif // #ifndef MOG_GENERIC_H
// Generated on Sat Apr 19 11:01:28 2008 for IT++ by Doxygen 1.5.5