y_uti のブログ

統計、機械学習、自然言語処理などに興味を持つエンジニアの技術ブログです

混合正規分布のパラメータ推定

『続・わかりやすいパターン認識』第 9 章のアルゴリズムを実装して、混合正規分布のパラメータ推定実験を試してみます。今回も MATLAB で実装しました。

サンプルデータの生成

まず、教科書 9.5 節にしたがってパラメータを設定します。各クラスの事前確率は教科書では π で記されていますが、MATLAB では pi が円周率を表す定数として定義されているので、第 8 章の記法にならって rho で表すことにしました。

>> mu = [3, -1];

>> sigma = sqrt([1, 1]);

>> rho = [0.6, 0.4];

このパラメータにしたがう乱数を生成します。次のように関数として実装しました。引数のうち mu, sigma, rho はモデルパラメータです。n はデータ数を指定します。関数の戻り値は s がクラス、x が観測結果です。

function [s, x] = GenerateSample(mu, sigma, rho, n)
    s = sample(rho, n);
    x = normrnd(mu(s), sigma(s));
end

function i = sample(p, n)
    i = 1 + sum(bsxfun(@gt, rand(n, 1), cumsum(p) / sum(p)), 2)';
end

この関数を実行して 500 個のデータを生成します。先頭の 10 個のデータを表示させてみると、s = 1 では x が正の数、s = 2 では x が負の数になっています。パラメータ mu を [3, -1] と設定していたので、これは妥当な結果に見えます。

>> [s, x] = GenerateSample(mu, sigma, rho, 500);

>> s(1:10)
ans =
     2     2     1     2     2     1     1     1     2     2

>> x(1:10)
ans =
   -0.3989   -0.9077    4.7298   -1.6086   -1.7371    1.2501    3.9105    3.8671   -1.0799   -0.1015

得られたデータの度数分布を描画してみます。最小値と最大値を調べて適当に描画範囲を決定し、histogram 関数で描画します。教科書のように、ふたつの正規分布の混合になっている様子がわかります。

>> [min(x) max(x)]
ans =
   -3.2033    6.5267

>> histogram(x, [-4:0.25:7] + 0.125);


パラメータ推定アルゴリズムの実装

教科書 180-181 ページの「混合正規分布のパラメータ推定」アルゴリズムを実装します。引数には、データ x, 各パラメータの初期値 mu1, sigma1, rho1, 最大反復回数 maxiter と収束判定の閾値 epsilon を与えます。戻り値は推定された mu, sigma, rho と対数尤度 logLH です。戻り値に関しては、反復の各回での値をすべて戻すようにしました*1。calcProbabilities 関数は、教科書 (9.56) 式の分子を各 i, k について計算します。これは Step 2 のほかに対数尤度の計算にも利用されます。

function [mu, sigma, rho, logLH] = Estimate(x, mu1, sigma1, rho1, maxiter, epsilon)
    % Step 1
    rho(1, :) = rho1;
    mu(1, :) = mu1;
    sigma(1, :) = sigma1;
    probs = calcProbabilities(x, mu(1, :), sigma(1, :), rho(1, :));
    logLH(1) = sum(log(sum(probs, 1)), 2);
    for j = 2:maxiter
        % Step 2
        P = bsxfun(@rdivide, probs, sum(probs, 1));
        % Step 3
        rho(j, :) = sum(P, 2)' ./ length(x);
        sigma(j, :) = sqrt(sum(P .* bsxfun(@minus, x, mu(j-1, :)') .^ 2, 2) ./ sum(P, 2))';
        mu(j, :) = sum(bsxfun(@times, P, x), 2)' ./ sum(P, 2)';
        % Step 4
        probs = calcProbabilities(x, mu(j, :), sigma(j, :), rho(j, :));
        logLH(j) = sum(log(sum(probs, 1)), 2);
        if logLH(j) - logLH(j-1) < epsilon
            break;
        end
    end
end

function p = calcProbabilities(x, mu, sigma, rho)
    c = length(mu);
    n = length(x);
    p = bsxfun(@times, rho', normpdf(repmat(x, c, 1), repmat(mu', 1, n), repmat(sigma', 1, n)));
end
パラメータ推定実験

実装したプログラムを用いて、パラメータ推定を実行します。教科書 9.5 節では正規分布の平均 μ のみを推定対象としていますが、今回の実験では μ のほかに標準偏差 σ や事前分布 π も推定します。初期値は次のように設定しました。mu1 は教科書にならって設定し、sigma1 と rho1 は正しい値を初期値にしています。

>> mu1 = [-2, -3];

>> sigma1 = sqrt([1, 1]);

>> rho1 = [0.6 0.4];

最大反復回数を 1,000 回、収束判定の条件を 1e-3 として、実装した関数を呼び出します。結果は以下のとおりでした。それぞれ真のモデルパラメータに近い値が推定されています。反復回数は 27 回でした*2

>> [Emu, Esigma, Erho, logLH] = Estimate(x, mu1, sigma1, rho1, 1000, 1e-3);

>> Emu(end,:)
ans =
    2.8468   -1.0159

>> Esigma(end,:)
ans =
    1.0335    0.8884

>> Erho(end,:)
ans =
    0.6192    0.3808

>> logLH(end)
ans =
  -1.0036e+03

>> length(logLH) - 1
ans =
    27

各パラメータの収束の様子を示します。左のグラフが μ、右のグラフが σ で、それぞれ青色の系列が ω1、赤色の系列が ω2 に対応します。

クラスの事前確率 π は次のとおりです。こちらも青色の系列が P(ω1)、赤色の系列が P(ω2) を表しますが、こちらは当然、P(ω1) + P(ω2) = 1 になります。

対数尤度は次のとおりです。ただしパラメータの初期値での対数尤度 logPH(1) は -4646 という値で極端に小さいためグラフの見やすさのために除外し、logPH(2) 以降を描画しています。

MATLAB の fitgmdist 関数を試す

[2015-06-11 追記]

MATLAB の Statistics Toolbox には、混合ガウス分布のパラメータ推定を行う fitgmdist 関数が用意されています。この関数を利用して、結果を比較してみます。MATLAB の公式ドキュメントは以下にあります。
混合ガウス分布 - MATLAB & Simulink - MathWorks 日本

次のように関数を実行します。データ x は列ベクトルで渡す必要があるようです。実行結果は gmm というオブジェクトに格納されます。実行例に表示されているように、真のモデルパラメータに近い値が推定されました。

>> gmm = fitgmdist(x', 2)
gmm =
1 次元に 2 要素をもつ混合ガウス分布
要素 1
混合比:  0.618626
平均:    2.8488
要素 2
混合比:  0.381374
平均:   -1.0134

gmm は、gmdistribution クラスのインスタンスです。このクラスの説明は以下にあります。
混合ガウス モデル - MATLAB - MathWorks 日本

各フィールドの値を適当に表示させてみます。

>> gmm.ComponentProportion
ans =
    0.6186    0.3814

>> gmm.mu
ans =
    2.8488
   -1.0134

>> squeeze(gmm.Sigma)
ans =
    1.0646
    0.7931

>> gmm.NumIterations
ans =
    34

>> gmm.NegativeLogLikelihood
ans =
   1.0036e+03

得られた混合ガウス分布から、次のようにサンプルを生成できます。コード例では、生成したサンプルの度数分布を表示しています。

>> x2 = gmm.random(10000);

>> [min(x2) max(x2)]
ans =
   -4.0141    6.2175

>> histogram(x2, [-5.5:0.5:7] + 0.25);

その他にもさまざまなメソッドが用意されています。以下では pdf メソッドを用いて確率密度関数を描画します。

>> plot(-5:0.1:7, gmm.pdf([-5:0.1:7]'));

*1:収束の様子を確認するために、このような実装にしています。パラメータ推定本来の目的には最後の値だけを戻せば十分です。

*2:logLH(1) は、初期値として与えたパラメータで反復処理の開始前に計算された対数尤度です。教科書の Step 2 から Step 4 の反復処理を実行した回数は (length(logLH) - 1) 回です。