混合正規分布のパラメータ推定
『続・わかりやすいパターン認識』第 9 章のアルゴリズムを実装して、混合正規分布のパラメータ推定実験を試してみます。今回も MATLAB で実装しました。
サンプルデータの生成
まず、教科書 9.5 節にしたがってパラメータを設定します。各クラスの事前確率は教科書では π で記されていますが、MATLAB では pi が円周率を表す定数として定義されているので、第 8 章の記法にならって rho で表すことにしました。
>> mu = [3, -1]; >> sigma = sqrt([1, 1]); >> rho = [0.6, 0.4];
このパラメータにしたがう乱数を生成します。次のように関数として実装しました。引数のうち mu, sigma, rho はモデルパラメータです。n はデータ数を指定します。関数の戻り値は s がクラス、x が観測結果です。
function [s, x] = GenerateSample(mu, sigma, rho, n) s = sample(rho, n); x = normrnd(mu(s), sigma(s)); end function i = sample(p, n) i = 1 + sum(bsxfun(@gt, rand(n, 1), cumsum(p) / sum(p)), 2)'; end
この関数を実行して 500 個のデータを生成します。先頭の 10 個のデータを表示させてみると、s = 1 では x が正の数、s = 2 では x が負の数になっています。パラメータ mu を [3, -1] と設定していたので、これは妥当な結果に見えます。
>> [s, x] = GenerateSample(mu, sigma, rho, 500); >> s(1:10) ans = 2 2 1 2 2 1 1 1 2 2 >> x(1:10) ans = -0.3989 -0.9077 4.7298 -1.6086 -1.7371 1.2501 3.9105 3.8671 -1.0799 -0.1015
得られたデータの度数分布を描画してみます。最小値と最大値を調べて適当に描画範囲を決定し、histogram 関数で描画します。教科書のように、ふたつの正規分布の混合になっている様子がわかります。
>> [min(x) max(x)] ans = -3.2033 6.5267 >> histogram(x, [-4:0.25:7] + 0.125);
パラメータ推定アルゴリズムの実装
教科書 180-181 ページの「混合正規分布のパラメータ推定」アルゴリズムを実装します。引数には、データ x, 各パラメータの初期値 mu1, sigma1, rho1, 最大反復回数 maxiter と収束判定の閾値 epsilon を与えます。戻り値は推定された mu, sigma, rho と対数尤度 logLH です。戻り値に関しては、反復の各回での値をすべて戻すようにしました*1。calcProbabilities 関数は、教科書 (9.56) 式の分子を各 i, k について計算します。これは Step 2 のほかに対数尤度の計算にも利用されます。
function [mu, sigma, rho, logLH] = Estimate(x, mu1, sigma1, rho1, maxiter, epsilon) % Step 1 rho(1, :) = rho1; mu(1, :) = mu1; sigma(1, :) = sigma1; probs = calcProbabilities(x, mu(1, :), sigma(1, :), rho(1, :)); logLH(1) = sum(log(sum(probs, 1)), 2); for j = 2:maxiter % Step 2 P = bsxfun(@rdivide, probs, sum(probs, 1)); % Step 3 rho(j, :) = sum(P, 2)' ./ length(x); sigma(j, :) = sqrt(sum(P .* bsxfun(@minus, x, mu(j-1, :)') .^ 2, 2) ./ sum(P, 2))'; mu(j, :) = sum(bsxfun(@times, P, x), 2)' ./ sum(P, 2)'; % Step 4 probs = calcProbabilities(x, mu(j, :), sigma(j, :), rho(j, :)); logLH(j) = sum(log(sum(probs, 1)), 2); if logLH(j) - logLH(j-1) < epsilon break; end end end function p = calcProbabilities(x, mu, sigma, rho) c = length(mu); n = length(x); p = bsxfun(@times, rho', normpdf(repmat(x, c, 1), repmat(mu', 1, n), repmat(sigma', 1, n))); end
パラメータ推定実験
実装したプログラムを用いて、パラメータ推定を実行します。教科書 9.5 節では正規分布の平均 μ のみを推定対象としていますが、今回の実験では μ のほかに標準偏差 σ や事前分布 π も推定します。初期値は次のように設定しました。mu1 は教科書にならって設定し、sigma1 と rho1 は正しい値を初期値にしています。
>> mu1 = [-2, -3]; >> sigma1 = sqrt([1, 1]); >> rho1 = [0.6 0.4];
最大反復回数を 1,000 回、収束判定の条件を 1e-3 として、実装した関数を呼び出します。結果は以下のとおりでした。それぞれ真のモデルパラメータに近い値が推定されています。反復回数は 27 回でした*2。
>> [Emu, Esigma, Erho, logLH] = Estimate(x, mu1, sigma1, rho1, 1000, 1e-3); >> Emu(end,:) ans = 2.8468 -1.0159 >> Esigma(end,:) ans = 1.0335 0.8884 >> Erho(end,:) ans = 0.6192 0.3808 >> logLH(end) ans = -1.0036e+03 >> length(logLH) - 1 ans = 27
各パラメータの収束の様子を示します。左のグラフが μ、右のグラフが σ で、それぞれ青色の系列が ω1、赤色の系列が ω2 に対応します。
クラスの事前確率 π は次のとおりです。こちらも青色の系列が P(ω1)、赤色の系列が P(ω2) を表しますが、こちらは当然、P(ω1) + P(ω2) = 1 になります。
対数尤度は次のとおりです。ただしパラメータの初期値での対数尤度 logPH(1) は -4646 という値で極端に小さいためグラフの見やすさのために除外し、logPH(2) 以降を描画しています。
MATLAB の fitgmdist 関数を試す
[2015-06-11 追記]
MATLAB の Statistics Toolbox には、混合ガウス分布のパラメータ推定を行う fitgmdist 関数が用意されています。この関数を利用して、結果を比較してみます。MATLAB の公式ドキュメントは以下にあります。
混合ガウス分布 - MATLAB & Simulink - MathWorks 日本
次のように関数を実行します。データ x は列ベクトルで渡す必要があるようです。実行結果は gmm というオブジェクトに格納されます。実行例に表示されているように、真のモデルパラメータに近い値が推定されました。
>> gmm = fitgmdist(x', 2) gmm = 1 次元に 2 要素をもつ混合ガウス分布 要素 1 混合比: 0.618626 平均: 2.8488 要素 2 混合比: 0.381374 平均: -1.0134
gmm は、gmdistribution クラスのインスタンスです。このクラスの説明は以下にあります。
混合ガウス モデル - MATLAB - MathWorks 日本
各フィールドの値を適当に表示させてみます。
>> gmm.ComponentProportion ans = 0.6186 0.3814 >> gmm.mu ans = 2.8488 -1.0134 >> squeeze(gmm.Sigma) ans = 1.0646 0.7931 >> gmm.NumIterations ans = 34 >> gmm.NegativeLogLikelihood ans = 1.0036e+03
得られた混合ガウス分布から、次のようにサンプルを生成できます。コード例では、生成したサンプルの度数分布を表示しています。
>> x2 = gmm.random(10000); >> [min(x2) max(x2)] ans = -4.0141 6.2175 >> histogram(x2, [-5.5:0.5:7] + 0.25);
その他にもさまざまなメソッドが用意されています。以下では pdf メソッドを用いて確率密度関数を描画します。
>> plot(-5:0.1:7, gmm.pdf([-5:0.1:7]'));