您的位置:首页 > 编程语言 > C语言/C++

HMM model 例子 Biased coins

2014-01-22 02:21 435 查看
输入

$ head toss1.txt

H

T

H

T

T

H

T

H

T

T

H for 正面, T for 反面。

硬币可能是正常的(fair coin),也可能是有偏的(biased coin,即 H、T 的概率不是各 0.5);两种状态按一定概率相互转换,转换概率未知。

输出

TIME    TOSS    P(FAIR) P(BIAS) MLSTATE

1       H       0.3131  0.6869  BIASED

2       T       0.9991  0.0009  FAIR

3       H       0.9788  0.0212  FAIR

4       T       0.9978  0.0022  FAIR

5       T       0.9978  0.0022  FAIR

6       H       0.9788  0.0212  FAIR

7       T       0.9978  0.0022  FAIR

8       H       0.9788  0.0212  FAIR

9       T       0.9978  0.0022  FAIR

10      T       0.9978  0.0022  FAIR

根据 sph.umich  Biostatistics 615课程课件改编。

main.cpp
这里设定:一旦重新估计出的任一转移概率或发射概率小于 0.001 就停止迭代(否则会过度训练);EM 算法(Baum-Welch)最多迭代 100 次。

#include <iostream>
#include <iomanip>
#include "HMM615.h"

using namespace std;

/**
 * Fits a two-state (FAIR / BIASED) coin HMM to a toss sequence.
 *
 * Reads whitespace-separated tokens "H" or "T" from standard input, runs
 * Baum-Welch re-estimation (at most 100 passes, stopping early as soon as any
 * re-estimated transition or emission probability drops below the threshold),
 * then prints per-toss posterior state probabilities, the Viterbi state and
 * the last re-estimated transition and emission matrices.
 *
 * @return 0 on success; -1 if a token is not "H"/"T" or fewer than two tosses
 *         were supplied.
 */
int main(int argc, char** argv) {
    const double tran_threshold = 0.001;
    const int maxPasses = 100;

    // Encode the observations: H -> 0, T -> 1.
    std::vector<int> toss;
    std::string tok;
    while (std::cin >> tok) {
        if (tok == "H") {
            toss.push_back(0);
        } else if (tok == "T") {
            toss.push_back(1);
        } else {
            std::cerr << "Cannot recognize input " << tok << std::endl;
            return -1;
        }
    }

    // With no observation the HMM routines index an empty table, and with a
    // single one there is no transition to estimate (0/0), so reject both.
    if (toss.size() < 2) {
        std::cerr << "Need at least two tosses (H or T) on standard input" << std::endl;
        return -1;
    }

    const int T = static_cast<int>(toss.size());
    HMM615 hmm(2, 2, T);

    // Initial guess: state 0 is the fair coin, state 1 the biased one.
    hmm.trans.data[0][0] = 0.95; hmm.trans.data[0][1] = 0.05;
    hmm.trans.data[1][0] = 0.2;  hmm.trans.data[1][1] = 0.8;
    hmm.emis.data[0][0] = 0.5;   hmm.emis.data[0][1] = 0.5;
    hmm.emis.data[1][0] = 0.9;   hmm.emis.data[1][1] = 0.1;
    hmm.pis[0] = 0.5;
    hmm.pis[1] = 0.5;
    hmm.outs = toss;

    hmm.viterbi();
    hmm.BaumWelch();

    for (int pass = 1; pass < maxPasses; ++pass) {
        // Stop before adopting an estimate that has collapsed towards zero.
        bool collapsed = false;
        for (int i = 0; i < hmm.nStates && !collapsed; ++i) {
            for (int j = 0; j < hmm.nStates && !collapsed; ++j) {
                collapsed = hmm.TRANS.data[i][j] < tran_threshold;
            }
            for (int k = 0; k < hmm.nObs && !collapsed; ++k) {
                collapsed = hmm.EMIS.data[i][k] < tran_threshold;
            }
        }
        if (collapsed) {
            break;
        }

        // Adopt the re-estimated parameters for the next pass.
        for (int i = 0; i < hmm.nStates; ++i) {
            for (int j = 0; j < hmm.nStates; ++j) {
                hmm.trans.data[i][j] = hmm.TRANS.data[i][j];
            }
            for (int k = 0; k < hmm.nObs; ++k) {
                hmm.emis.data[i][k] = hmm.EMIS.data[i][k];
            }
        }

        hmm.zero_sigma();
        hmm.viterbi();
        hmm.BaumWelch();
    }

    std::cout << "TIME\tTOSS\tP(FAIR)\tP(BIAS)\tMLSTATE" << std::endl;
    std::cout << std::setiosflags(std::ios::fixed) << std::setprecision(4);
    for (int t = 0; t < T; ++t) {
        std::cout << t + 1 << "\t"
                  << (toss[t] == 0 ? "H" : "T") << "\t"
                  << hmm.gammas.data[t][0] << "\t"
                  << hmm.gammas.data[t][1] << "\t"
                  << (hmm.path[t] == 0 ? "FAIR" : "BIASED") << std::endl;
    }

    // Last re-estimated parameters: transition rows, then emission rows.
    std::cout << hmm.TRANS.data[0][0] << "\t" << hmm.TRANS.data[0][1] << std::endl;
    std::cout << hmm.TRANS.data[1][0] << "\t" << hmm.TRANS.data[1][1] << std::endl;
    std::cout << hmm.EMIS.data[0][0] << "\t" << hmm.EMIS.data[0][1] << std::endl;
    std::cout << hmm.EMIS.data[1][0] << "\t" << hmm.EMIS.data[1][1] << std::endl;
    return 0;
}

HMM615.h 
#ifndef __HMM_615_H
#define __HMM_615_H
#include "Matrix615.h"
#include "MatrixTrible.h"
#include <cmath>

/**
 * Discrete hidden Markov model with storage for forward-backward,
 * Baum-Welch re-estimation and Viterbi decoding.
 *
 * All members are public; callers fill in `pis`, `trans`, `emis` and `outs`
 * before invoking the algorithms.
 */
class HMM615 {
public:
    // parameters
    int nStates; // n : number of possible states
    int nObs;    // m : number of possible output values
    int nTimes;  // t : number of time slots with observations
    std::vector<double> pis;  // initial state probabilities, one per state
    std::vector<int> outs;    // observed outcomes, indexed by time; values index emis columns
    Matrix615<double> trans;  // trans[i][j] corresponds to A_{ij} (state i -> state j)
    Matrix615<double> emis;   // emis[i][k] : probability that state i emits output k
    Matrix615<double> TRANS;  // re-estimated trans written by BaumWelch()
    Matrix615<double> EMIS;   // re-estimated emis written by BaumWelch()

    // storages for dynamic programming
    Matrix615<double> alphas, betas, gammas, deltas; // each nTimes x nStates
    Matrix615<int> phis;          // Viterbi back-pointers, nTimes x nStates
    MatrixTrible<double> sigma;   // pairwise posteriors, nTimes x nStates x nStates
    std::vector<int> path;        // decoded state per time slot

    /**
     * Allocates zero-filled tables for a model of the given size.
     *
     * TRANS and EMIS are shaped like trans (states x states) and emis
     * (states x obs) so that indexing them by state/output is always in range,
     * independent of the number of time slots. Members are initialised in
     * declaration order.
     *
     * @param states number of hidden states
     * @param obs    number of distinct output values
     * @param times  number of observations
     */
    HMM615(int states, int obs, int times)
        : nStates(states), nObs(obs), nTimes(times),
          pis(states),
          trans(states, states, 0), emis(states, obs, 0),
          TRANS(states, states, 0), EMIS(states, obs, 0),
          alphas(times, states, 0), betas(times, states, 0),
          gammas(times, states, 0), deltas(times, states, 0),
          phis(times, states, 0),
          sigma(times, states, states, 0),
          path(times)
    {
    }

    void forward();         // fills alphas
    void backward();        // fills betas
    void forwardBackward(); // fills alphas, betas, gammas and sigma
    void BaumWelch();       // runs forwardBackward(), then fills TRANS and EMIS
    void zero_sigma();      // clears sigma
    void viterbi();         // fills deltas, phis and path
};
#endif // __HMM_615_H

// Clears sigma[t][i][j] for every time slot t in [0, nTimes-2] and every
// ordered state pair (i, j). The final time slot is left untouched.
void HMM615::zero_sigma() {
    const int lastPairSlot = nTimes - 1;
    for (int slot = 0; slot < lastPairSlot; ++slot) {
        std::vector< std::vector<double> >& pairTable = sigma.data[slot];
        for (int from = 0; from < nStates; ++from) {
            std::vector<double>& row = pairTable[from];
            for (int to = 0; to < nStates; ++to) {
                row[to] = 0;
            }
        }
    }
}

/**
 * Forward pass: fills alphas[t][i] with the joint probability of the
 * observations outs[0..t] and of being in state i at time t, using the current
 * pis, trans and emis.
 *
 * Does nothing when there are no time slots. Probabilities are multiplied
 * directly (no scaling), so very long sequences can underflow to zero.
 * NOTE(review): assumes outs holds at least nTimes entries, each a valid
 * column index of emis - confirm at the call site.
 */
void HMM615::forward() {
    if (nTimes <= 0) {
        return; // no observation: alphas has no row 0 to write
    }

    const int first = outs[0];
    for (int i = 0; i < nStates; ++i) {
        alphas.data[0][i] = pis[i] * emis.data[i][first];
    }

    for (int t = 1; t < nTimes; ++t) {
        const int seen = outs[t];
        for (int i = 0; i < nStates; ++i) {
            double total = 0;
            for (int j = 0; j < nStates; ++j) {
                total += alphas.data[t-1][j] * trans.data[j][i] * emis.data[i][seen];
            }
            alphas.data[t][i] = total;
        }
    }
}

/**
 * Backward pass: fills betas[t][i] with the probability of the observations
 * outs[t+1..nTimes-1] given state i at time t, using the current trans and
 * emis. The last row is set to 1.
 *
 * Does nothing when there are no time slots. Probabilities are multiplied
 * directly (no scaling), so very long sequences can underflow to zero.
 */
void HMM615::backward() {
    if (nTimes <= 0) {
        return; // no observation: betas has no last row to seed
    }

    const int last = nTimes - 1;
    for (int i = 0; i < nStates; ++i) {
        betas.data[last][i] = 1;
    }

    for (int t = last - 1; t >= 0; --t) {
        const int next = outs[t+1];
        for (int i = 0; i < nStates; ++i) {
            double total = 0;
            for (int j = 0; j < nStates; ++j) {
                total += betas.data[t+1][j] * trans.data[i][j] * emis.data[j][next];
            }
            betas.data[t][i] = total;
        }
    }
}

void HMM615::forwardBackward() {
forward();
backward();
for(int t=0; t < nTimes; ++t) {
double sum = 0;
for(int i=0; i < nStates; ++i) {
double tmp = std::log( alphas.data[t][i] ) + std::log( betas.data[t][i] );
sum += std::exp(tmp);
}
for(int i=0; i < nStates; ++i) {
double tmp = std::log( alphas.data[t][i] ) + std::log( betas.data[t][i] ) - std::log( sum );
gammas.data[t][i] = std::exp(tmp);
}
}
for(int t=0; t < nTimes-1; ++t) {
double sum = 0;
for(int i=0; i < nStates; ++i){
for (int j=0; j < nStates; ++j){
double tmp = std::log( alphas.data[t][i] ) + std::log( trans.data[i][j] ) + std::log( emis.data[j][outs[t+1]] ) + std::log( betas.data[t+1][j] );
sum += std::exp(tmp);
}
}
for(int i=0; i < nStates; ++i){
for (int j=0; j < nStates; ++j){
double tmp = std::log( alphas.data[t][i] ) + std::log( trans.data[i][j] ) + std::log( emis.data[j][outs[t+1]] ) + std::log( betas.data[t+1][j] ) - std::log( sum );
sigma.data[t][i][j] += std::exp(tmp);
}
}
}
}

void HMM615::BaumWelch() {
forwardBackward();
for (int i=0; i<nStates; ++i){
double sum_gamma1 = 0;
for (int t=0; t<nTimes-1; ++t){
sum_gamma1 += gammas.data[t][i];
}
for (int j=0; j<nStates; ++j){
double sum_sigma1 = 0;
for (int t=0; t<nTimes-1; ++t){
sum_sigma1 += sigma.data[t][i][j];
}
TRANS.data[i][j] = sum_sigma1 / sum_gamma1;
}
}

for (int i=0; i<nStates; ++i){
for (int k=0; k<nObs; ++k){
double sum_gamma2 = 0;
double sum_gamma3 = 0;
for (int t=0; t<nTimes; ++t){
sum_gamma2 += gammas.data[t][i];
if (outs[t]==k){
sum_gamma3 += gammas.data[t][i];
}
}
EMIS.data[i][k]= sum_gamma3 / sum_gamma2;
}
}
}

/**
 * Viterbi decoding with the current pis, trans and emis.
 *
 * Fills deltas[t][i] (probability of the best state sequence ending in state
 * i at time t), phis[t][i] (the predecessor state on that sequence, for
 * t >= 1) and path[t] (the decoded state at every time slot). Ties are
 * resolved in favour of the lowest state index.
 *
 * Does nothing when there are no time slots. Probabilities are multiplied
 * directly (no scaling), so very long sequences can underflow to zero.
 */
void HMM615::viterbi() {
    if (nTimes <= 0) {
        return; // nothing to decode; deltas/path have no entries
    }

    for (int i = 0; i < nStates; ++i) {
        deltas.data[0][i] = pis[i] * emis.data[i][outs[0]];
    }

    for (int t = 1; t < nTimes; ++t) {
        const int seen = outs[t];
        for (int i = 0; i < nStates; ++i) {
            int bestPrev = 0;
            double bestVal = deltas.data[t-1][0] * trans.data[0][i] * emis.data[i][seen];
            for (int j = 1; j < nStates; ++j) {
                const double val = deltas.data[t-1][j] * trans.data[j][i] * emis.data[i][seen];
                if (val > bestVal) {
                    bestPrev = j;
                    bestVal = val;
                }
            }
            deltas.data[t][i] = bestVal;
            phis.data[t][i] = bestPrev;
        }
    }

    // Most probable final state. It always belongs in the LAST path slot;
    // indexing the slot by the state number would corrupt earlier entries
    // as soon as there are more than two states.
    const int last = nTimes - 1;
    int bestFinal = 0;
    for (int i = 1; i < nStates; ++i) {
        if (deltas.data[last][bestFinal] < deltas.data[last][i]) {
            bestFinal = i;
        }
    }
    path[last] = bestFinal;

    // Follow the back-pointers towards the start.
    for (int t = last - 1; t >= 0; --t) {
        path[t] = phis.data[t+1][path[t+1]];
    }
}

Matrix615.h
#ifndef __MATRIX_615_H
#define __MATRIX_615_H
#include <vector>

/**
 * Minimal two-dimensional table stored as a vector of row vectors.
 * `data` is public and is indexed as data[row][col].
 */
template <class T>
class Matrix615 {
public:
    std::vector< std::vector<T> > data;

    /**
     * Builds an nrow x ncol table with every cell set to val.
     *
     * @param nrow number of rows
     * @param ncol number of columns
     * @param val  initial value of each cell
     */
    Matrix615(int nrow, int ncol, T val = 0)
        : data(nrow, std::vector<T>(ncol, val)) {}

    /** @return number of rows. */
    int rowNums() const { return (int)data.size(); }

    /** @return number of columns of the first row, or 0 when there are no rows. */
    int colNums() const { return data.empty() ? 0 : (int)data[0].size(); }
};
#endif // __MATRIX_615_H

MatrixTrible.h
#ifndef __MATRIX_TRIBLE_H
#define __MATRIX_TRIBLE_H
#include <vector>

/**
 * Minimal three-dimensional table stored as nested vectors.
 * `data` is public and is indexed as data[row][col][tri].
 */
template <class T>
class MatrixTrible {
public:
    std::vector< std::vector< std::vector<T> > > data;

    /**
     * Builds an nrow x ncol x ntri table with every cell set to val.
     *
     * @param nrow size of the first dimension
     * @param ncol size of the second dimension
     * @param ntri size of the third dimension
     * @param val  initial value of each cell
     */
    MatrixTrible(int nrow, int ncol, int ntri, T val = 0)
        : data(nrow,
               std::vector< std::vector<T> >(ncol, std::vector<T>(ntri, val))) {}

    /** @return size of the first dimension. */
    int rowNums() const { return (int)data.size(); }

    /** @return size of the second dimension, or 0 when the table has no rows. */
    int colNums() const { return data.empty() ? 0 : (int)data[0].size(); }

    /** @return size of the third dimension, or 0 when either outer dimension is empty. */
    int triNums() const {
        if (data.empty() || data[0].empty()) {
            return 0;
        }
        return (int)data[0][0].size();
    }
};
#endif // __MATRIX_TRIBLE_H

Makefile
# Builds the coin-toss HMM demo. Recipe lines must start with a TAB character.
cc=g++
obj=main.o
exe=hmm

# Link step; first target, so it is what a bare `make` builds.
$(exe): $(obj)
	$(cc) -o $(exe) $(obj)

# main.cpp includes HMM615.h, which in turn includes both matrix headers,
# so a change to any of them must trigger a recompile.
main.o: main.cpp HMM615.h Matrix615.h MatrixTrible.h
	$(cc) -c main.cpp

# `clean` names no file; mark it phony so a stray file called "clean"
# cannot disable it.
.PHONY: clean
clean:
	rm -rf *.o $(exe)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  EM 算法 HMM C++