自然语言处理系列之Viterbi算法

来源:互联网 发布:淘宝客服昵称怎么修改 编辑:程序博客网 时间:2024/05/22 12:58

  前面已经介绍了隐马尔可夫模型,本篇博文主要是介绍用 viterbi 算法来解决 HMM 中的预测问题,也称为解码问题。
  维特比算法实际是用动态规划解隐马尔可夫模型预测问题,即用动态规划(dynamic programming)求概率最大路径(最优路径)。这时一条路径对应着一个状态序列。
  根据动态规划原理,最优路径具有这样的特性:如果最优路径在时刻t通过(it),那么这一路径从it到终点iT的部分路径,对于从itiT的所有可能的部分路径来说,必须是最优的。因为假如不是这样,那么从i1到终点iT就有另一条更好的部分路径存在,如果把它和i1到终点it的部分路径连接起来,就会形成一条比原来的路径更优的路径,这是矛盾的。依据这一原理,我们只需从时刻t=1开始,递推地计算在时刻t状态为i的各条部分路径的最大概率,直至得到时刻t=T状态为i的各条路径的最大概率。时刻t=T的最大概率即为最优路径的概率P,最优路径的终结点iT也同时得到。之后,为了找出最优路径的各个结点,从终结点iT开始,由后向前逐步求得结点iT1,...,i1得到最优路径这就是维特比算法。

  • viterbi 算法
    输入:模型λ=(A,B,π)和观测O=(o1,o2,...,oT);
    输出:最优路径(i1,...,iT1,iT).
    (1) 初始化

    δ1(i)=πibi(oi),i=1,2,...,N

    ψ1(i)=0,i=1,2,...,N

    (2) 递推.对t=2,3,...,T
    δt(i)=max[δt1(j)aji]bi(ot),i=1,2,..,N;1jN

    ψt(i)=argmax[δt1(j)aji],i=1,2,...,N;1jN

    (3) 终止
    P=maxδT(i),1jN

    iT=argmax[δT(i)],1jN

    (4)最优路径回溯. 对t=T1,T2,...,1
    it=ψt+1(it+1)

  • viterbi算法实现

package com.feng.nlp.algorithm;import java.util.*;/** * Created by lionel on 17/4/11. */public class Viterbi {    public static List<String> compute(String[] observe, String[] status, double[] start_p, double[][] transfer_p, double[][] observe_p) {        double[][] theta = new double[observe.length][status.length];        int[][] delta = new int[observe.length][status.length];        transfermation(start_p, transfer_p, observe_p);        for (int j = 0; j < status.length; j++) {            theta[0][j] = start_p[j] + observe_p[j][0];            delta[0][j] = 0;        }        Map<String, Integer> map = new HashMap<String, Integer>();        int index = 0;        for (String ele : observe) {            if (map.containsKey(ele)) {                continue;            }            map.put(ele, index);            index++;        }        for (int i = 1; i < observe.length; i++) {            for (int j = 0; j < status.length; j++) {                int direction = 0;                double prob = Double.MAX_VALUE;                for (int k = 0; k < status.length; k++) {                    double tmpProb = theta[i - 1][k] + transfer_p[k][j] + observe_p[j][map.get(observe[i])];                    if (tmpProb < prob) {                        prob = tmpProb;                        direction = k;                        theta[i][j] = prob;                    }                }                delta[i][j] = direction;            }        }//        for (int i = 0; i < theta.length; i++) {//            for (int j = 0; j < theta[i].length; j++) {//                System.out.print(theta[i][j] + " ");//            }//            System.out.println();//        }        double prob = Double.MAX_VALUE;        int pos = 0;        for (int j = 0; j < status.length; j++) {            if (theta[observe.length - 1][j] < prob) {                prob = theta[observe.length - 1][j];                pos = j;            }        }        List<String> res = new ArrayList<String>();        res.add(status[pos]);        //回溯路径        for (int i = observe.length - 1; i > 0; i--) {            res.add(status[delta[i][pos]]);            pos = delta[i][pos];        }        Collections.reverse(res);        return res;    }    public static void transfermation(double[] start_p, double[][] transfer_p, double[][] observe_p) {        for (int i = 0; i < start_p.length; ++i) {            start_p[i] = -Math.log(start_p[i]);        }        for (int i = 0; i < transfer_p.length; ++i) {            for (int j = 0; j < transfer_p[i].length; ++j) {                transfer_p[i][j] = -Math.log(transfer_p[i][j]);            }        }        for (int i = 0; i < observe_p.length; ++i) {            for (int j = 0; j < observe_p[i].length; ++j) {                observe_p[i][j] = -Math.log(observe_p[i][j]);            }        }    }    public static void main(String[] args) {        String[] observe = {"红", "白", "红"};        String[] status = {"1", "2", "3"};        double[] start_p = new double[]{0.2, 0.4, 0.4};        double[][] transfer_p = new double[][]{                {0.5, 0.2, 0.3},                {0.3, 0.5, 0.2},                {0.2, 0.3, 0.5}        };        double[][] observe_p = new double[][]{                {0.5, 0.5},                {0.4, 0.6},                {0.7, 0.3}        };        List<String> result = compute(observe, status, start_p, transfer_p, observe_p);        System.out.println(result);//[3, 3, 3]    }}

  测试用例来源于李航老师的《统计机器学习》的例子。

  • 参考资料:《统计机器学习》,李航
0 0
原创粉丝点击