隐马尔科夫模型的解码算法和前向算法

来源:互联网 发布:php 工作流 开源 编辑:程序博客网 时间:2024/05/22 22:24


隐马尔科夫模型的解码算法和前向算法

隐马尔科夫模型算是语音识别的支柱了,观察序列是语音信号的MFCC编码,得到的隐藏序列就是音素了。音素加上神经网络就可以构建一个语音识别系统。最近也在学习这个方面,所以试着先把解码算法(维特比算法)和计算观察序列出现概率的前向算法实现了一下。用的语言是Java,感觉IDEA比eclipse舒心。

其中前向算法的代码是在解码算法的基础上简单改的,在时间和空间上还有很多可以优化的地方,但是目前并没有这个需求(微笑)


关于HMM模型的资料,找到了这两篇很好的文章,真是非常感谢作者啊

贝叶斯网络简介 : http://blog.csdn.net/memory513773348/article/details/16973807

HMM学习最佳范例: http://www.52nlp.cn/hmm-learn-best-practices-seven-forward-backward-algorithm-5


HMM和贝叶斯网络都属于概率图模型,相关的教材有Koller的经典著作《Probabilistic Graphical Models》,这是一部1200多页的巨著(微笑)




/**
 * Discrete hidden Markov model helpers: Viterbi decoding and the forward algorithm.
 *
 * <p>Conventions shared by every method in this class:
 * <ul>
 *   <li>{@code trans[i][j]} - probability of moving from hidden state {@code i} to hidden state {@code j}</li>
 *   <li>{@code mis[i][o]} - probability that hidden state {@code i} emits observation symbol {@code o}</li>
 *   <li>{@code init[i]} - probability that the chain starts in hidden state {@code i}</li>
 *   <li>{@code view[t]} - index of the symbol observed at time step {@code t}</li>
 * </ul>
 * The number of hidden states is taken from {@code trans.length}.
 */
public class HMM {

    /** Small demo: decodes a fixed observation sequence and prints its total probability. */
    public static void main(String[] args) {
        double[][] trans = {
            {0.01, 0.5, 0.49},
            {0.18, 0.01, 0.81},
            {0.7, 0.29, 0.01}
        };
        double[][] mis = {
            {0.1, 0.8, 0.1},
            {0.3, 0.3, 0.4},
            {0.2, 0.1, 0.7}
        };
        double[] init = {0.1, 0.5, 0.4};
        int[] view = {2, 1, 2, 0, 2, 1, 2, 1, 0, 0, 0, 1};

        int[] res = HMM.decode(view, trans, mis, init);
        for (int i : res) {
            System.out.print(i + "  ");
        }
        System.out.println("");

        double r = HMM.forward(view, trans, mis, init);
        System.out.println(r);
    }

    /**
     * Returns the index of the largest element of {@code input}.
     *
     * <p>Ties resolve to the lowest index. Index 0 is returned when the array is empty or when
     * no element is strictly greater than negative infinity.
     *
     * @param input values to scan
     * @return index of the first maximum, or 0 if there is none
     */
    public static int argmax(double[] input) {
        // Start from -infinity rather than a magic sentinel so that arbitrarily small
        // log-probabilities are still compared correctly.
        double best = Double.NEGATIVE_INFINITY;
        int index = 0;
        for (int i = 0; i < input.length; ++i) {
            if (input[i] > best) {
                best = input[i];
                index = i;
            }
        }
        return index;
    }

    /**
     * Viterbi decoding: finds the most probable hidden state sequence for an observation sequence.
     *
     * <p>Scores are accumulated as log-probabilities, so zero probabilities become negative
     * infinity and simply never win a comparison.
     *
     * @param view observed symbol indices, one per time step
     * @param trans state transition matrix, {@code trans[from][to]}
     * @param mis emission matrix, {@code mis[state][symbol]}
     * @param init initial state distribution
     * @return the most probable hidden state for every time step; empty when {@code view} is empty
     * @throws IllegalArgumentException if {@code view} is non-empty but the model has no states
     */
    public static int[] decode(int[] view, double[][] trans, double[][] mis, double[] init) {
        int statsNum = trans.length;
        int viewNum = view.length;
        int[] result = new int[viewNum];
        if (viewNum == 0) {
            return result;
        }
        if (statsNum == 0) {
            throw new IllegalArgumentException("cannot decode observations with a model that has no states");
        }

        // Trellis column for t = 0: start probability times emission of the first symbol.
        node[] previous = new node[statsNum];
        for (int i = 0; i < statsNum; ++i) {
            previous[i] = new node(i, Math.log(init[i]) + Math.log(mis[i][view[0]]), null);
        }

        // Extend the trellis one observation at a time.
        double[] candidates = new double[statsNum];
        for (int t = 1; t < viewNum; ++t) {
            node[] current = new node[statsNum];
            for (int j = 0; j < statsNum; ++j) {
                // Score of reaching state j from each predecessor k: the transition must be
                // k -> j, i.e. trans[k][j].
                for (int k = 0; k < statsNum; ++k) {
                    candidates[k] = previous[k].log_p + Math.log(trans[k][j]);
                }
                int best = argmax(candidates);
                // The emission term does not depend on k, so it is added after the maximisation.
                current[j] = new node(j, candidates[best] + Math.log(mis[j][view[t]]), previous[best]);
            }
            previous = current;
        }

        // Pick the best final state (lowest index wins ties).
        node end = previous[0];
        for (int i = 1; i < statsNum; ++i) {
            if (previous[i].log_p > end.log_p) {
                end = previous[i];
            }
        }

        // Walk the back-pointers from the last time step to the first.
        for (int t = viewNum - 1; t >= 0; --t) {
            result[t] = end.name;
            end = end.back;
        }
        return result;
    }

    /**
     * Forward algorithm: computes the total probability of an observation sequence under the model.
     *
     * <p>The forward variables are rescaled to sum to one after every step and the logarithms of
     * the scale factors are accumulated, which keeps the intermediate values well conditioned.
     * The returned value is still a plain probability and may therefore underflow to 0 for very
     * long sequences.
     *
     * @param view observed symbol indices, one per time step
     * @param trans state transition matrix, {@code trans[from][to]}
     * @param mis emission matrix, {@code mis[state][symbol]}
     * @param init initial state distribution
     * @return probability of observing {@code view}; 1.0 for an empty sequence, 0.0 when the
     *     sequence is impossible under the model
     */
    public static double forward(int[] view, double[][] trans, double[][] mis, double[] init) {
        int statsNum = trans.length;
        int viewNum = view.length;
        if (viewNum == 0) {
            // The empty sequence is the only sequence of length zero.
            return 1.0;
        }

        double[] alpha = new double[statsNum];
        for (int i = 0; i < statsNum; ++i) {
            alpha[i] = init[i] * mis[i][view[0]];
        }

        double logLikelihood = 0.0;
        for (int t = 0; t < viewNum; ++t) {
            if (t > 0) {
                double[] next = new double[statsNum];
                for (int j = 0; j < statsNum; ++j) {
                    double reach = 0.0;
                    for (int k = 0; k < statsNum; ++k) {
                        reach += alpha[k] * trans[k][j];
                    }
                    next[j] = reach * mis[j][view[t]];
                }
                alpha = next;
            }
            double scale = normalize(alpha);
            if (scale == 0.0) {
                // No state can explain the observations seen so far.
                return 0.0;
            }
            logLikelihood += Math.log(scale);
        }
        return Math.exp(logLikelihood);
    }

    /** Rescales {@code values} in place to sum to one and returns the original sum (0 leaves it untouched). */
    private static double normalize(double[] values) {
        double sum = 0.0;
        for (double v : values) {
            sum += v;
        }
        if (sum != 0.0) {
            for (int i = 0; i < values.length; ++i) {
                values[i] /= sum;
            }
        }
        return sum;
    }
}

/** One cell of the Viterbi trellis: a hidden state at one time step plus the best path into it. */
class node {
    public node(int name, double log_p, node back) {
        this.name = name;
        this.log_p = log_p;
        this.back = back;
    }

    /** Index of the hidden state this cell represents. */
    int name;
    /** Predecessor cell on the best path, or {@code null} at the first time step. */
    node back;
    /** Log-probability of the best path that ends in this cell. */
    double log_p;
}

0 0