gmm.hpp

Go to the documentation of this file.
00001 
00023 #ifndef __MLPACK_METHODS_MOG_MOG_EM_HPP
00024 #define __MLPACK_METHODS_MOG_MOG_EM_HPP
00025 
00026 #include <mlpack/core.hpp>
00027 
00028 // This is the default fitting method class.
00029 #include "em_fit.hpp"
00030 
00031 namespace mlpack {
00032 namespace gmm  {
00033 
00088 template<typename FittingType = EMFit<> >
00089 class GMM
00090 {
00091  private:
00093   size_t gaussians;
00095   size_t dimensionality;
00097   std::vector<arma::vec> means;
00099   std::vector<arma::mat> covariances;
00101   arma::vec weights;
00102 
00103  public:
00107   GMM() :
00108       gaussians(0),
00109       dimensionality(0),
00110       localFitter(FittingType()),
00111       fitter(localFitter)
00112   {
00113     // Warn the user.  They probably don't want to do this.  If this constructor
00114     // is being used (because it is required by some template classes), the user
00115     // should know that it is potentially dangerous.
00116     Log::Debug << "GMM::GMM(): no parameters given; Estimate() may fail "
00117         << "unless parameters are set." << std::endl;
00118   }
00119 
00127   GMM(const size_t gaussians, const size_t dimensionality) :
00128       gaussians(gaussians),
00129       dimensionality(dimensionality),
00130       means(gaussians, arma::vec(dimensionality)),
00131       covariances(gaussians, arma::mat(dimensionality, dimensionality)),
00132       weights(gaussians),
00133       localFitter(FittingType()),
00134       fitter(localFitter) { /* Nothing to do. */ }
00135 
00146   GMM(const size_t gaussians,
00147       const size_t dimensionality,
00148       FittingType& fitter) :
00149       gaussians(gaussians),
00150       dimensionality(dimensionality),
00151       means(gaussians, arma::vec(dimensionality)),
00152       covariances(gaussians, arma::mat(dimensionality, dimensionality)),
00153       weights(gaussians),
00154       fitter(fitter) { /* Nothing to do. */ }
00155 
00163   GMM(const std::vector<arma::vec>& means,
00164       const std::vector<arma::mat>& covariances,
00165       const arma::vec& weights) :
00166       gaussians(means.size()),
00167       dimensionality((!means.empty()) ? means[0].n_elem : 0),
00168       means(means),
00169       covariances(covariances),
00170       weights(weights),
00171       localFitter(FittingType()),
00172       fitter(localFitter) { /* Nothing to do. */ }
00173 
00183   GMM(const std::vector<arma::vec>& means,
00184       const std::vector<arma::mat>& covariances,
00185       const arma::vec& weights,
00186       FittingType& fitter) :
00187       gaussians(means.size()),
00188       dimensionality((!means.empty()) ? means[0].n_elem : 0),
00189       means(means),
00190       covariances(covariances),
00191       weights(weights),
00192       fitter(fitter) { /* Nothing to do. */ }
00193 
00197   template<typename OtherFittingType>
00198   GMM(const GMM<OtherFittingType>& other);
00199 
00204   GMM(const GMM& other);
00205 
00209   template<typename OtherFittingType>
00210   GMM& operator=(const GMM<OtherFittingType>& other);
00211 
00216   GMM& operator=(const GMM& other);
00217 
00224   void Load(const std::string& filename);
00225 
00231   void Save(const std::string& filename) const;
00232 
00234   size_t Gaussians() const { return gaussians; }
00237   size_t& Gaussians() { return gaussians; }
00238 
00240   size_t Dimensionality() const { return dimensionality; }
00243   size_t& Dimensionality() { return dimensionality; }
00244 
00246   const std::vector<arma::vec>& Means() const { return means; }
00248   std::vector<arma::vec>& Means() { return means; }
00249 
00251   const std::vector<arma::mat>& Covariances() const { return covariances; }
00253   std::vector<arma::mat>& Covariances() { return covariances; }
00254 
00256   const arma::vec& Weights() const { return weights; }
00258   arma::vec& Weights() { return weights; }
00259 
00261   const FittingType& Fitter() const { return fitter; }
00263   FittingType& Fitter() { return fitter; }
00264 
00271   double Probability(const arma::vec& observation) const;
00272 
00280   double Probability(const arma::vec& observation,
00281                      const size_t component) const;
00282 
00289   arma::vec Random() const;
00290 
00313   double Estimate(const arma::mat& observations,
00314                   const size_t trials = 1,
00315                   const bool useExistingModel = false);
00316 
00341   double Estimate(const arma::mat& observations,
00342                   const arma::vec& probabilities,
00343                   const size_t trials = 1,
00344                   const bool useExistingModel = false);
00345 
00362   void Classify(const arma::mat& observations,
00363                 arma::Col<size_t>& labels) const;
00364 
00365  private:
00375   double LogLikelihood(const arma::mat& dataPoints,
00376                        const std::vector<arma::vec>& means,
00377                        const std::vector<arma::mat>& covars,
00378                        const arma::vec& weights) const;
00379 
00381   FittingType localFitter;
00382 
00384   FittingType& fitter;
00385 };
00386 
00387 }; // namespace gmm
00388 }; // namespace mlpack
00389 
00390 // Include implementation.
00391 #include "gmm_impl.hpp"
00392 
00393 #endif

Generated on 29 Sep 2016 for MLPACK by  doxygen 1.6.1