#include <itpp/srccode/gmm.h>
#include <itpp/srccode/vqtrain.h>
#include <itpp/base/math/elem_math.h>
#include <itpp/base/matfunc.h>
#include <itpp/base/specmat.h>
#include <itpp/base/random.h>
#include <itpp/base/timing.h>
#include <iostream>
#include <fstream>


namespace itpp {

GMM::GMM()
{
  d = 0;
  M = 0;
}

GMM::GMM(std::string filename)
{
  load(filename);
}

GMM::GMM(int M_in, int d_in)
{
  M = M_in;
  d = d_in;
  m = zeros(M * d);
  sigma = zeros(M * d);
  w = 1.0 / M * ones(M);
  compute_internals();
}

void GMM::init_from_vq(const vec &codebook, int dim)
{
  mat C(dim, dim);
  int i;
  vec v;

  d = dim;
  M = codebook.length() / dim;

  m = codebook;
  w = ones(M) / double(M);

  C.clear();
  for (i = 0; i < M; i++) {
    v = codebook.mid(i * d, d);
    C = C + outer_product(v, v);
  }
  C = 1.0 / M * C;
  sigma.set_length(M * d);
  for (i = 0; i < M; i++) {
    sigma.replace_mid(i * d, diag(C));
  }

  compute_internals();
}

void GMM::init(const vec &w_in, const mat &m_in, const mat &sigma_in)
{
  int i, j;
  d = m_in.rows();
  M = m_in.cols();

  m.set_length(M * d);
  sigma.set_length(M * d);
  for (i = 0; i < M; i++) {
    for (j = 0; j < d; j++) {
      m(i * d + j) = m_in(j, i);
      sigma(i * d + j) = sigma_in(j, i);
    }
  }
  w = w_in;

  compute_internals();
}

void GMM::set_mean(const mat &m_in)
{
  int i, j;

  d = m_in.rows();
  M = m_in.cols();

  m.set_length(M * d);
  for (i = 0; i < M; i++) {
    for (j = 0; j < d; j++) {
      m(i * d + j) = m_in(j, i);
    }
  }
  compute_internals();
}

void GMM::set_mean(int i, const vec &means, bool compflag)
{
  m.replace_mid(i * length(means), means);
  if (compflag) compute_internals();
}

void GMM::set_covariance(const mat &sigma_in)
{
  int i, j;

  d = sigma_in.rows();
  M = sigma_in.cols();

  sigma.set_length(M * d);
  for (i = 0; i < M; i++) {
    for (j = 0; j < d; j++) {
      sigma(i * d + j) = sigma_in(j, i);
    }
  }
  compute_internals();
}

void GMM::set_covariance(int i, const vec &covariances, bool compflag)
{
  sigma.replace_mid(i * length(covariances), covariances);
  if (compflag) compute_internals();
}

void GMM::marginalize(int d_new)
{
  it_error_if(d_new > d, "GMM.marginalize: cannot change to a larger dimension");

  vec mnew(d_new * M), sigmanew(d_new * M);
  int i, j;

  for (i = 0; i < M; i++) {
    for (j = 0; j < d_new; j++) {
      mnew(i * d_new + j) = m(i * d + j);
      sigmanew(i * d_new + j) = sigma(i * d + j);
    }
  }
  m = mnew;
  sigma = sigmanew;
  d = d_new;

  compute_internals();
}

void GMM::join(const GMM &newgmm)
{
  if (d == 0) {
    w = newgmm.w;
    m = newgmm.m;
    sigma = newgmm.sigma;
    d = newgmm.d;
    M = newgmm.M;
  }
  else {
    it_error_if(d != newgmm.d, "GMM.join: cannot join GMMs of different dimension");

    w = concat(double(M) / (M + newgmm.M) * w, double(newgmm.M) / (M + newgmm.M) * newgmm.w);
    w = w / sum(w);
    m = concat(m, newgmm.m);
    sigma = concat(sigma, newgmm.sigma);

    M = M + newgmm.M;
  }
  compute_internals();
}
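// A minimal construction sketch for the column-per-mixture layout expected by
// init(), set_mean(const mat&) and set_covariance(const mat&): each column of
// the matrices holds the mean / diagonal covariance of one mixture. The values
// below are illustrative placeholders only.
//
//   vec w = "0.5 0.5";       // two equally weighted mixtures
//   mat mu = "0 3; 0 3";     // columns are the means (0,0) and (3,3)
//   mat sig = ones(2, 2);    // columns are unit diagonal covariances
//   GMM gmm;
//   gmm.init(w, mu, sig);    // likelihood(x), draw_sample() etc. are now usable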
void GMM::clear()
{
  w.set_length(0);
  m.set_length(0);
  sigma.set_length(0);
  d = 0;
  M = 0;
}

void GMM::save(std::string filename)
{
  std::ofstream f(filename.c_str());
  int i, j;

  f << M << " " << d << std::endl;
  for (i = 0; i < w.length(); i++) {
    f << w(i) << std::endl;
  }
  for (i = 0; i < M; i++) {
    f << m(i * d);
    for (j = 1; j < d; j++) {
      f << " " << m(i * d + j);
    }
    f << std::endl;
  }
  for (i = 0; i < M; i++) {
    f << sigma(i * d);
    for (j = 1; j < d; j++) {
      f << " " << sigma(i * d + j);
    }
    f << std::endl;
  }
}

void GMM::load(std::string filename)
{
  std::ifstream GMMFile(filename.c_str());
  int i, j;

  it_error_if(!GMMFile, std::string("GMM::load : cannot open file ") + filename);

  GMMFile >> M >> d;

  w.set_length(M);
  for (i = 0; i < M; i++) {
    GMMFile >> w(i);
  }
  m.set_length(M * d);
  for (i = 0; i < M; i++) {
    for (j = 0; j < d; j++) {
      GMMFile >> m(i * d + j);
    }
  }
  sigma.set_length(M * d);
  for (i = 0; i < M; i++) {
    for (j = 0; j < d; j++) {
      GMMFile >> sigma(i * d + j);
    }
  }
  compute_internals();
  std::cout << " mixtures:" << M << " dim:" << d << std::endl;
}

double GMM::likelihood(const vec &x)
{
  double fx = 0;
  int i;

  for (i = 0; i < M; i++) {
    fx += w(i) * likelihood_aposteriori(x, i);
  }
  return fx;
}

vec GMM::likelihood_aposteriori(const vec &x)
{
  vec v(M);
  int i;

  for (i = 0; i < M; i++) {
    v(i) = w(i) * likelihood_aposteriori(x, i);
  }
  return v;
}

double GMM::likelihood_aposteriori(const vec &x, int mixture)
{
  int j;
  double s;

  it_error_if(d != x.length(), "GMM::likelihood_aposteriori : dimensions do not match");
  s = 0;
  for (j = 0; j < d; j++) {
    s += normexp(mixture * d + j) * sqr(x(j) - m(mixture * d + j));
  }
  return normweight(mixture) * std::exp(s);
}

void GMM::compute_internals()
{
  int i, j;
  double s;
  double constant = 1.0 / std::pow(2 * pi, d / 2.0);

  normweight.set_length(M);
  normexp.set_length(M * d);

  for (i = 0; i < M; i++) {
    s = 1;
    for (j = 0; j < d; j++) {
      normexp(i * d + j) = -0.5 / sigma(i * d + j); // exponent coefficients of the diagonal Gaussian
      s *= sigma(i * d + j);
    }
    normweight(i) = constant / std::sqrt(s); // 1 / sqrt((2*pi)^d * det(Sigma_i))
  }
}

vec GMM::draw_sample()
{
  static bool first = true;
  static vec cumweight;
  double u = randu();
  int k;

  if (first) {
    first = false;
    cumweight = cumsum(w);
    it_error_if(std::abs(cumweight(length(cumweight) - 1) - 1) > 1e-6, "weights do not sum to 1");
    cumweight(length(cumweight) - 1) = 1;
  }
  k = 0;
  while (u > cumweight(k)) k++;

  return elem_mult(sqrt(sigma.mid(k * d, d)), randn(d)) + m.mid(k * d, d);
}

GMM gmmtrain(Array<vec> &TrainingData, int M, int NOITER, bool VERBOSE)
{
  mat mean;
  int i, j, d = TrainingData(0).length();
  vec sig;
  GMM gmm(M, d);
  vec m(d * M);
  vec sigma(d * M);
  vec w(M);
  vec normweight(M);
  vec normexp(d * M);
  double LL = 0, LLold, fx;
  double constant = 1.0 / std::pow(2 * pi, d / 2.0);
  int T = TrainingData.length();
  vec x1;
  int t, n;
  vec msum(d * M);
  vec sigmasum(d * M);
  vec wsum(M);
  vec p_aposteriori(M);
  vec x2;
  double s;
  vec temp1, temp2;
  //double MINIMUM_VARIANCE=0.03;

  //----------- initialization -----------------------------------

  mean = vqtrain(TrainingData, M, 200000, 0.5, VERBOSE);
  for (i = 0; i < M; i++) gmm.set_mean(i, mean.get_col(i), false);
  // for (i=0;i<M;i++) gmm.set_mean(i,TrainingData(randi(0,TrainingData.length()-1)),false);
  sig = zeros(d);
  for (i = 0; i < TrainingData.length(); i++) sig += sqr(TrainingData(i));
  sig /= TrainingData.length();
  for (i = 0; i < M; i++) gmm.set_covariance(i, 0.5 * sig, false);

  gmm.set_weight(1.0 / M * ones(M));

  //----------- optimization -----------------------------------

  tic();
  for (i = 0; i < M; i++) {
    temp1 = gmm.get_mean(i);
    temp2 = gmm.get_covariance(i);
    for (j = 0; j < d; j++) {
      m(i * d + j) = temp1(j);
      sigma(i * d + j) = temp2(j);
    }
    w(i) = gmm.get_weight(i);
  }
  for (n = 0; n < NOITER; n++) {
    // precompute the per-mixture normalization constants for this iteration
    for (i = 0; i < M; i++) {
      s = 1;
      for (j = 0; j < d; j++) {
        normexp(i * d + j) = -0.5 / sigma(i * d + j);
        s *= sigma(i * d + j);
      }
      normweight(i) = constant * w(i) / std::sqrt(s);
    }
    LLold = LL;
    wsum.clear();
    msum.clear();
    sigmasum.clear();
    LL = 0;
    // E-step: accumulate responsibility-weighted sufficient statistics
    for (t = 0; t < T; t++) {
      x1 = TrainingData(t);
      x2 = sqr(x1);
      fx = 0;
      for (i = 0; i < M; i++) {
        s = 0;
        for (j = 0; j < d; j++) {
          s += normexp(i * d + j) * sqr(x1(j) - m(i * d + j));
        }
        p_aposteriori(i) = normweight(i) * std::exp(s);
        fx += p_aposteriori(i);
      }
      p_aposteriori /= fx;
      LL = LL + std::log(fx);

      for (i = 0; i < M; i++) {
        wsum(i) += p_aposteriori(i);
        for (j = 0; j < d; j++) {
          msum(i * d + j) += p_aposteriori(i) * x1(j);
          sigmasum(i * d + j) += p_aposteriori(i) * x2(j);
        }
      }
    }
    // M-step: update means, diagonal covariances and weights
    for (i = 0; i < M; i++) {
      for (j = 0; j < d; j++) {
        m(i * d + j) = msum(i * d + j) / wsum(i);
        sigma(i * d + j) = sigmasum(i * d + j) / wsum(i) - sqr(m(i * d + j));
      }
      w(i) = wsum(i) / T;
    }
    LL = LL / T;

    if (std::abs((LL - LLold) / LL) < 1e-6) break;
    if (VERBOSE) {
      std::cout << n << ": " << LL << " " << std::abs((LL - LLold) / LL) << " " << toc() << std::endl;
      std::cout << "---------------------------------------" << std::endl;
      tic();
    }
    else {
      std::cout << n << ": LL = " << LL << " " << std::abs((LL - LLold) / LL) << "\r";
      std::cout.flush();
    }
  }
  for (i = 0; i < M; i++) {
    gmm.set_mean(i, m.mid(i * d, d), false);
    gmm.set_covariance(i, sigma.mid(i * d, d), false);
  }
  gmm.set_weight(w);
  return gmm;
}

} // namespace itpp
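// A minimal usage sketch of the trainer and the resulting model, assuming the
// training set is already available as an Array<vec> of d-dimensional vectors;
// the variable names and parameter values are placeholders, not fixed choices.
//
//   Array<vec> data;                           // filled with one observation per element
//   GMM gmm = gmmtrain(data, 8, 100, false);   // 8 mixtures, at most 100 EM iterations
//   gmm.save("model.gmm");                     // text format written by GMM::save()
//   double fx = gmm.likelihood(data(0));       // mixture density at one point
//   vec x = gmm.draw_sample();                 // draw from the trained model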