y_uti のブログ

統計、機械学習、自然言語処理などに興味を持つエンジニアの技術ブログです

MATLAB で Baum-Welch アルゴリズムを実装する

前回に続いて、『続・わかりやすいパターン認識』のアルゴリズムを実装します。今回は Baum-Welch アルゴリズムを実装して、隠れマルコフモデルのパラメータ推定を試してみます。

前向きアルゴリズム

まず、前向きアルゴリズムを実装します。以下のように実装できます。引数は、A, B, rho がそれぞれ遷移確率、出力確率、初期状態、x が出力記号系列です。戻り値の a は教科書の (8.8) 式で定義される α です。ただし、アンダーフローを避けるために α そのものではなく log(α) を計算します。添字は、行方向を i, 列方向を t としています。

function a = Forward(A, B, rho, x)
    % Step 1  初期化
    a(:, 1) = log(rho') + log(B(:, x(1)));
    % Step 2  再帰的計算
    for t = 2:length(x)
        c = max(a(:, t-1));
        a(:, t) = log(((exp(a(:, t-1) - c))' * A))' + log(B(:, x(t))) + c;
    end
end

Step 2 の再帰的計算では、logsumexp と呼ばれる計算技法を利用しています。この教科書では logsumexp には触れられていないようですが、ウェブで検索すれば情報が見つかります。書籍では、たとえば『言語処理のための機械学習入門』の付録 A.2 に説明があります。

前向きアルゴリズムの動作を確認します。教科書の (8.116) 式、(8.117) 式のとおりにパラメータ A, B, rho を定義します。出力記号系列 x も (8.118) 式のとおりに定めます。

>> A = [
  0.1 0.7 0.2
  0.2 0.1 0.7
  0.7 0.2 0.1
];

>> B = [
  0.9 0.1
  0.6 0.4
  0.1 0.9
];

>> rho = [1 1 1] / 3;

>> x = [1 2 1];

Forward 関数を実行すると、以下のように log(α) が求まります。これを α に戻して値を確認すると、それぞれ教科書の値に一致していることがわかります*1。教科書 (8.13) 式にあるように、最後の時刻での i に関する和が P(x) になります。この例では 0.1734 と求まり、たしかに (8.120) 式の値と一致します。

>> alpha = Forward(A, B, rho, x)
alpha =
   -1.2040   -4.6742   -2.0161
   -1.6094   -2.3574   -3.4559
   -3.4012   -1.6983   -4.7510

>> exp(alpha)
ans =
    0.3000    0.0093    0.1332
    0.2000    0.0947    0.0316
    0.0333    0.1830    0.0086

>> sum(exp(alpha(:, end)))
ans =
    0.1734
後ろ向きアルゴリズム

後ろ向きアルゴリズムも前向きアルゴリズムと同様に実装できます。以下のようになります。

function b = Backward(A, B, rho, x)
    % Step 1  初期化
    b(:, length(x)) = log(ones(length(rho), 1));
    % Step 2  再帰的計算
    for t = length(x)-1:-1:1
        c = max(b(:, t+1));
        b(:, t) = log(A * exp(log(B(:, x(t+1))) + b(:, t+1) - c)) + c;
    end
end

後ろ向きアルゴリズムの動作を確認します。前向きアルゴリズムで利用したものと同じデータを用いて関数を実行すると、以下のように β が求まります。

>> beta = Backward(A, B, rho, x)
beta =
   -1.4745   -0.6349         0
   -0.6896   -1.1712         0
   -2.0379   -0.2744         0

β の定義 (8.9) から、P(x) は時刻 t = 1 での β を用いて次のように計算できます。

P(x) = Σ_i{P(s1=wi) * P(x1|s1=wi) * β(t=1, s1=wi)}

今回の例では 0.1734 と求まり、α から求めた場合と一致します。

>> sum(rho' .* B(:, x(1)) .* exp(beta(:, 1)))
ans =
    0.1734
Baum-Welch アルゴリズム

Baum-Welch アルゴリズムは次のように実装しました。引数は出力記号系列 x と状態数 nstate です。教科書 8.6 節では A, B, rho の初期値を固定していますが、私の実装では潜在状態数を指定してランダムな値を割り当てるようにしました。関数の戻り値は、推定されたパラメータ A, B, rho と、各反復での対数尤度 logP(x) です*2

function [A, B, rho, logLH] = BaumWelch(x, nstate)
    maxiter = 100;   % 最大反復回数を 100 回とする
    epsilon = 1e-3;  % 対数尤度の増加が 1e-3 未満なら収束したとみなして終了する 
    % Step 1  初期化
    [A, B, rho] = initialize(nstate, max(x));
    % Step 2  再帰的計算
    a = Forward(A, B, rho, x);
    b = Backward(A, B, rho, x);
    % Step 4  判定 (初期パラメータでの対数尤度を計算する)
    logLH(1) = calcLikelihood(a);
    for i = 2:maxiter
        % Step 3  パラメータの更新
        [A, B, rho] = maximize(A, B, x, a, b);
        % Step 2  再帰的計算
        a = Forward(A, B, rho, x);
        b = Backward(A, B, rho, x);
        % Step 4  判定
        logLH(i) = calcLikelihood(a);
        if logLH(i) - logLH(i-1) < epsilon
            break;
        end
    end
end

function [A, B, rho] = initialize(c, m)
    A   = normalize(rand(c));
    B   = normalize(rand(c, m));
    rho = normalize(rand(1, c));
end

function [A, B, rho] = maximize(A, B, x, a, b)
    g  = calcGamma(a, b);
    xi = calcXi(A, B, x, a, b);
    A  = normalize(sum(exp(bsxfun(@minus, xi, max(max(xi, [], 3), [], 2))), 3));
    for k = 1:size(B, 2)
        B(:, k) = sum(exp(bsxfun(@minus, g(:, find(x == k)), max(g, [], 2))), 2);
    end
    B = normalize(B);
    rho = normalize(exp(g(:, 1) - max(g(:, 1)))');
end

function g = calcGamma(a, b)
    g = a + b;
end

function xi = calcXi(A, B, x, a, b)
    c  = size(A, 1);
    t  = length(x) - 1;
    xi = repmat(permute(a(:, 1:end-1)      , [1 3 2]), [1 c 1]) ...
       + repmat(permute(log(A)             , [1 2 3]), [1 1 t]) ...
       + repmat(permute(log(B(:, x(2:end))), [3 1 2]), [c 1 1]) ...
       + repmat(permute(b(:, 2:end)        , [3 1 2]), [c 1 1]);
end

function l = calcLikelihood(a)
    c = max(a(:, end));
    l = log(sum(exp(a(:, end) - c))) + c;
end

function M = normalize(M)
    M = bsxfun(@rdivide, M, sum(M, 2));
end

教科書のアルゴリズム説明では、Step 3 でパラメータの更新を行った後、Step 4 の判定で対数尤度を計算しています。ところが、対数尤度を計算するには α か β を求める必要があり、それは Step 2 の再帰的計算の処理に相当します。そのため、上述の実装例のように処理の順番を入れ替えました。

calcGamma 関数、calcXi 関数では、(8.20) 式の γ, (8.43) 式の ξ を計算します。ただし、アンダーフローを避けるために対数での計算としており、また、各式の分母は計算しません。これらの式の分母は和を 1 にするための係数です。したがって実装上は、分子を計算してから総和を求めて割れば、同じ結果になります。normalize 関数でこれを実現しています。

Baum-Welch アルゴリズムの動作確認

Baum-Welch アルゴリズムによるパラメータ推定を確認します。最初の例として、1, 2, 3 を繰り返す系列のパラメータを推定します。状態数を 2 とした推定は以下のようになりました。行列 B を見ると、状態 1, 2 のいずれの場合でも出力記号 1, 2, 3 がほぼ等確率で出力されます。あまり上手く推定できていないようです。

>> x = repmat([1 2 3], 1, 100);

>> [A, B, rho, logLH] = BaumWelch(x, 2)
A =
    0.8570    0.1430
    0.7402    0.2598
B =
    0.3383    0.3305    0.3312
    0.3077    0.3480    0.3443
rho =
    0.9874    0.0126
logLH =
 -365.7825 -332.0131 -330.6177 -330.1137 -329.8861 -329.7684 -329.7018 -329.6614 -329.6358 -329.6190 -329.6076 -329.5998 -329.5943 -329.5903 -329.5875 -329.5854 -329.5839 -329.5828 -329.5820

一方、同じ例で状態数を 3 にすると次のようになりました。こちらは、現在の状態によって出力記号が一意に決まり、次の状態への遷移も一意に決まるという推定結果になりました。出力状態系列は 1, 2, 3, 1, 2, 3, ... と繰り返すものでしたので、この結果は妥当なものだと思います。

>> [A, B, rho, logLH] = BaumWelch(x, 3)
A =
    0.0000    1.0000    0.0000
    0.0000    0.0000    1.0000
    1.0000    0.0000    0.0000
B =
    0.0000    0.0000    1.0000
    1.0000    0.0000    0.0000
    0.0000    1.0000    0.0000
rho =
    0.0000    1.0000    0.0000
logLH =
 -322.1001 -265.6989  -98.4127   -3.6947   -0.0005   -0.0000

次の例として、x のあとに x 自身を反転させたものを連結します。1, 2, 3, 1, 2, 3, ... を繰り返したのち、..., 3, 2, 1, 3, 2, 1, ... と逆回りをはじめる出力記号系列です。先ほどと同様に 3 状態で推定させてみた結果が以下です。状態 1 のときに 2 を出力し、状態が 2, 3 のときには 1, 3 のいずれかを出力するモデルが得られました。

>> x = [x x(end:-1:1)];

>> [A, B, rho, logLH] = BaumWelch(x, 3)
A =
         0    1.0000    0.0000
    0.0000    0.0000    1.0000
    1.0000    0.0000    0.0000
B =
    0.0000    1.0000    0.0000
    0.5000    0.0000    0.5000
    0.5000    0.0000    0.5000
rho =
         0    0.0000    1.0000
logLH =
 -730.4094 -652.2383 -644.0695 -633.4681 -618.4745 -597.6392 -574.0333 -558.5107 -554.4464 -553.7989 -553.3115 -552.1486 -548.6988 -538.3057 -508.0520 -434.2421 -331.4732 -282.3566 -277.3188 -277.2591 -277.2589

状態数を 6 にすると次の結果が得られました。状態 4 から始まって、4, 1, 2, 4, 1, 2, ... の状態を繰り返します。B を見ると、出力記号系列は 1, 2, 3, 1, 2, 3, ... になることが分かります。状態 2 からは 1% の確率で状態 3 に遷移します。状態 3 からは 3, 5, 6, 3, 5, 6, ... を繰り返し、対応する出力記号系列は 3, 2, 1, 3, 2, 1, ... となります。

>> [A, B, rho, logLH] = BaumWelch(x, 6)
A =
    0.0000    1.0000    0.0000    0.0000         0    0.0000
    0.0000    0.0000    0.0100    0.9900    0.0000    0.0000
    0.0000    0.0000    0.0000    0.0000    1.0000    0.0000
    1.0000    0.0000    0.0000    0.0000    0.0000    0.0000
    0.0000    0.0000    0.0000    0.0000    0.0000    1.0000
    0.0000    0.0000    1.0000    0.0000    0.0000    0.0000
B =
    0.0000    1.0000    0.0000
    0.0000    0.0000    1.0000
    0.0000    0.0000    1.0000
    1.0000    0.0000    0.0000
    0.0000    1.0000    0.0000
    1.0000    0.0000    0.0000
rho =
         0         0         0    1.0000    0.0000    0.0000
logLH =
 -680.2880 -647.1606 -636.3205 -619.0222 -581.4084 -495.4729 -346.5039 -200.9830 -116.3330  -71.1324 -33.2078   -9.9379   -5.6963   -5.6002   -5.6002

[以下 2015-06-05 追記]

最後の例として、教科書の 8.6 節と同じ設定でパラメータを推定してみます。まず、(8.116), (8.144) のとおりに真のパラメータを定めます。

>> A = [
  0.1 0.7 0.2
  0.2 0.1 0.7
  0.7 0.2 0.1
];

>> B = [
  0.9 0.1
  0.6 0.4
  0.1 0.9
];

>> rho = [1 0 0];

前回の記事で実装した GenerateSample 関数を使って、観測回数 n = 10000 の出力記号系列を生成します。なお、同時に状態系列も生成されますが、これは今回の実験では使いません。

>> [s, x] = GenerateSample(A, B, rho, 10000);

教科書の (8.145), (8.146) のとおりに推定の初期値を定めます。それぞれ A1, B1, rho1 としました。

>> A1 = [
    0.15 0.60 0.25
    0.25 0.15 0.60
    0.60 0.25 0.15
];
>> B1 = ones(3, 2) .* 0.5;
>> rho1 = [1 0 0];

今回作成したプログラムでは推定の初期値をランダムに生成していましたので、初期値を引数として渡せるように微修正します。修正後のコードは次のようになります。教科書の記述に「ほぼ 150 回で収束した」とあるので、終了条件も変更しました。

% function [A, B, rho, logLH] = BaumWelch(x, nstate)
function [A, B, rho, logLH] = BaumWelch(x, A, B, rho)
    maxiter = 1000;  % 100 から 1000 に変更しました
    epsilon = 1e-4;  % 1e-3 から 1e-4 に変更しました
    % [A, B, rho] = initialize(nstate, max(x));
    a = Forward(A, B, rho, x);
    b = Backward(A, B, rho, x);

...

パラメータ推定の結果は以下のとおりでした。パラメータ A は真の値に近い推定結果が得られていますが、B の方は行を入れ替えたような推定結果になっています。このようになる原因はよくわかりませんでした。

>> [Ae, Be, rhoe, logLH] = BaumWelch(x, A1, B1, rho1);

>> Ae
Ae =
    0.0632    0.7522    0.1846
    0.2198    0.1260    0.6542
    0.7467    0.1518    0.1015

>> Be
Be =
    0.1244    0.8756
    0.8749    0.1251
    0.5959    0.4041

>> rhoe
rhoe =
     1     0     0

この推定実験の各反復での対数尤度は、次のグラフのとおりです。これも、教科書の図 8.5 では -3000 程度からはじまり -2930 あたりで収束しているのに対して、今回の実験では -6900 から -6750 程度と低い値になっています。
f:id:y_uti:20150605075327p:plain

真のパラメータと推定されたパラメータのそれぞれで logP(x) を計算してみた結果は以下のとおりです。

>> a = Forward(A, B, rho, x);
>> log(sum(exp(a(:,end) - max(a(:, end))))) + max(a(:,end))
ans =
  -6.7482e+03

>> ae = Forward(Ae, Be, rhoe, x);
>> log(sum(exp(ae(:,end) - max(ae(:, end))))) + max(ae(:,end))
ans =
  -6.7458e+03

また、それぞれのパラメータのもとで状態系列の最尤推定を行い条件付き確率 logP(x|s) を計算すると、次のようになりました。

>> s_ml = Viterbi(A, B, rho, x);
>> sum(log(B(sub2ind(size(B), s_ml, x))))
ans =
  -3.5488e+03

>> se_ml = Viterbi(Ae, Be, rhoe, x);
>> sum(log(Be(sub2ind(size(Be), se_ml, x))))
ans =
  -3.9603e+03

% 出力記号系列を生成したときの状態系列のもとでの条件付き確率 (比較として)
>> sum(log(B(sub2ind(size(B), s, x))))
ans =
  -4.3743e+03

*1:教科書 8.6 節の (8.122), (8.124), (8.125), (8.127), (8.128) と比較してください。

*2:本来 logLH を戻す必要はないのですが、確認のため戻り値に加えています。