hmm.hpp

Go to the documentation of this file.
00001 
00023 #ifndef __MLPACK_METHODS_HMM_HMM_HPP
00024 #define __MLPACK_METHODS_HMM_HMM_HPP
00025 
00026 #include <mlpack/core.hpp>
00027 
00028 namespace mlpack {
00029 namespace hmm  {
00030 
00092 template<typename Distribution = distribution::DiscreteDistribution>
00093 class HMM
00094 {
      // HMM model class (presumably "hidden Markov model", going by the class
      // and namespace names), parameterized on the emission distribution type.
      // Distribution defaults to distribution::DiscreteDistribution.
      //
      // Only declarations and trivial inline accessors appear in this class
      // body; the non-inline members are presumably defined in hmm_impl.hpp,
      // which this header includes after the class -- confirm there before
      // relying on any behavioral note below.
00095  public:
        // Construct from a number of states and a single Distribution object,
        // with an optional tolerance (default 1e-5).  `emissions` is taken by
        // value, so the argument is copied at the call site.
        // NOTE(review): presumably `emissions` is used to initialize the
        // emission distribution of every state -- confirm in the
        // implementation.
00110   HMM(const size_t states,
00111       const Distribution emissions,
00112       const double tolerance = 1e-5);
00113 
        // Construct from an explicit transition matrix and a vector of
        // emission distributions, with an optional tolerance (default 1e-5).
        // NOTE(review): presumably `emission` holds one distribution per state
        // and `transition` is square in the number of states -- confirm.
00135   HMM(const arma::mat& transition,
00136       const std::vector<Distribution>& emission,
00137       const double tolerance = 1e-5);
00138 
        // Train the model from a collection of observation sequences only (one
        // arma::mat per sequence).  Non-const: modifies the model.
        // NOTE(review): presumably unlabeled (Baum-Welch style) training that
        // uses `tolerance` as its stopping criterion -- confirm.
00167   void Train(const std::vector<arma::mat>& dataSeq);
00168 
        // Train the model from observation sequences together with a state
        // sequence (arma::Col<size_t>) for each of them.
        // NOTE(review): presumably dataSeq[i] and stateSeq[i] must have
        // matching lengths (labeled training) -- confirm.
00190   void Train(const std::vector<arma::mat>& dataSeq,
00191              const std::vector<arma::Col<size_t> >& stateSeq);
00192 
        // Estimate quantities for one observation sequence without modifying
        // the model (const).  stateProb, forwardProb, backwardProb and scales
        // are non-const references, i.e. output parameters.
        // NOTE(review): the returned double is presumably the log-likelihood of
        // dataSeq -- confirm.
00211   double Estimate(const arma::mat& dataSeq,
00212                   arma::mat& stateProb,
00213                   arma::mat& forwardProb,
00214                   arma::mat& backwardProb,
00215                   arma::vec& scales) const;
00216 
        // Overload of Estimate() that exposes only the stateProb output.
00228   double Estimate(const arma::mat& dataSeq,
00229                   arma::mat& stateProb) const;
00230 
        // Generate a sequence of the requested length without modifying the
        // model (const).  The generated observations and the corresponding
        // states are written to dataSequence and stateSequence; startState
        // defaults to 0.
00242   void Generate(const size_t length,
00243                 arma::mat& dataSequence,
00244                 arma::Col<size_t>& stateSequence,
00245                 const size_t startState = 0) const;
00246 
        // Compute a state sequence for the given observations, written to
        // stateSeq; the model is not modified (const).
        // NOTE(review): presumably the most probable (Viterbi) path, with the
        // returned double being its log-likelihood -- confirm.
00257   double Predict(const arma::mat& dataSeq,
00258                  arma::Col<size_t>& stateSeq) const;
00259 
        // Return a double for the given observation sequence -- by its name,
        // the log-likelihood of dataSeq under this model.
00266   double LogLikelihood(const arma::mat& dataSeq) const;
00267 
        // Read-only, then mutable, access to the transition matrix.
00269   const arma::mat& Transition() const { return transition; }
00271   arma::mat& Transition() { return transition; }
00272 
        // Read-only, then mutable, access to the emission distributions.
00274   const std::vector<Distribution>& Emission() const { return emission; }
00276   std::vector<Distribution>& Emission() { return emission; }
00277 
        // Get (by value), then modify (by reference), the stored
        // dimensionality.
00279   size_t Dimensionality() const { return dimensionality; }
00281   size_t& Dimensionality() { return dimensionality; }
00282 
        // Get (by value), then modify (by reference), the stored tolerance.
00284   double Tolerance() const { return tolerance; }
00286   double& Tolerance() { return tolerance; }
00287 
00288  private:
00289   // Helper functions.
00290 
        // Forward pass over dataSeq.  Both scales and forwardProb are non-const
        // references, so this helper fills them in.
00301   void Forward(const arma::mat& dataSeq,
00302                arma::vec& scales,
00303                arma::mat& forwardProb) const;
00304 
        // Backward pass over dataSeq.  scales is taken by const reference
        // (input only); backwardProb is the output.
        // NOTE(review): presumably expects the scales produced by Forward() --
        // confirm.
00316   void Backward(const arma::mat& dataSeq,
00317                 const arma::vec& scales,
00318                 arma::mat& backwardProb) const;
00319 
        // Transition matrix, exposed through Transition().
00321   arma::mat transition;
00322 
        // Emission distributions, exposed through Emission().
00324   std::vector<Distribution> emission;
00325 
        // Stored dimensionality, exposed through Dimensionality().
        // NOTE(review): presumably the dimensionality of the observations --
        // confirm.
00327   size_t dimensionality;
00328 
        // Tolerance value passed to the constructors (default 1e-5), exposed
        // through Tolerance().
00330   double tolerance;
00331 };
00332 
00333 }; // namespace hmm
00334 }; // namespace mlpack
00335 
00336 // Include implementation.
00337 #include "hmm_impl.hpp"
00338 
00339 #endif

Generated on 29 Sep 2016 for MLPACK by doxygen 1.6.1