IT++ Logo

gmm.h

Go to the documentation of this file.
00001 
00030 #ifndef GMM_H
00031 #define GMM_H
00032 
00033 #include <itpp/base/mat.h>
00034 
00035 
00036 namespace itpp
00037 {
00038 
00040 
00046 class GMM
00047 {
00048 public:
00049   GMM();
00050   GMM(int nomix, int dim);
00051   GMM(std::string filename);
00052   void init_from_vq(const vec &codebook, int dim);
00053   // void init(const vec &w_in, const vec &m_in, const vec &sigma_in);
00054   void init(const vec &w_in, const mat &m_in, const mat &sigma_in);
00055   void load(std::string filename);
00056   void save(std::string filename);
00057   void set_weight(const vec &weights, bool compflag = true);
00058   void set_weight(int i, double weight, bool compflag = true);
00059   void set_mean(const mat &m_in);
00060   void set_mean(const vec &means, bool compflag = true);
00061   void set_mean(int i, const vec &means, bool compflag = true);
00062   void set_covariance(const mat &sigma_in);
00063   void set_covariance(const vec &covariances, bool compflag = true);
00064   void set_covariance(int i, const vec &covariances, bool compflag = true);
00065   int get_no_mixtures();
00066   int get_no_gaussians() const { return M; }
00067   int get_dimension();
00068   vec get_weight();
00069   double get_weight(int i);
00070   vec get_mean();
00071   vec get_mean(int i);
00072   vec get_covariance();
00073   vec get_covariance(int i);
00074   void marginalize(int d_new);
00075   void join(const GMM &newgmm);
00076   void clear();
00077   double likelihood(const vec &x);
00078   double likelihood_aposteriori(const vec &x, int mixture);
00079   vec likelihood_aposteriori(const vec &x);
00080   vec draw_sample();
00081 protected:
00082   vec   m, sigma, w;
00083   int   M, d;
00084 private:
00085   void  compute_internals();
00086   vec   normweight, normexp;
00087 };
00088 
00089 inline void GMM::set_weight(const vec &weights, bool compflag) {w = weights; if (compflag) compute_internals(); }
00090 inline void GMM::set_weight(int i, double weight, bool compflag) {w(i) = weight; if (compflag) compute_internals(); }
00091 inline void GMM::set_mean(const vec &means, bool compflag) {m = means; if (compflag) compute_internals(); }
00092 inline void GMM::set_covariance(const vec &covariances, bool compflag) {sigma = covariances; if (compflag) compute_internals(); }
00093 inline int GMM::get_no_mixtures()
00094 {
00095   it_warning("GMM::get_no_mixtures(): This function is depreceted and might be removed from feature releases. Please use get_no_gaussians() instead.");
00096   return M;
00097 }
00098 inline int GMM::get_dimension() {return d;}
00099 inline vec GMM::get_weight() {return w;}
00100 inline double GMM::get_weight(int i) {return w(i);}
00101 inline vec GMM::get_mean() {return m;}
00102 inline vec GMM::get_mean(int i) {return m.mid(i*d, d);}
00103 inline vec GMM::get_covariance() {return sigma;}
00104 inline vec GMM::get_covariance(int i) {return sigma.mid(i*d, d);}
00105 
00106 GMM gmmtrain(Array<vec> &TrainingData, int M, int NOITER = 30, bool VERBOSE = true);
00107 
00109 
00110 } // namespace itpp
00111 
00112 #endif // #ifndef GMM_H
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
SourceForge Logo

Generated on Wed Mar 2 2011 22:05:10 for IT++ by Doxygen 1.7.3