続パタ 2.3 節「ベイズ更新の実験」を試す
『続・わかりやすいパターン認識』の 2.3 節「ベイズ更新の実験」を MATLAB で試してみました。読書会に参加しながら読み進めているのですが、前半部分は残念ながら参加できなかったので、復習の意味で読み直しています。
教科書 2.3 節の実験はコイン投げを題材としたものです。まず、教科書を引用しながら実験内容を説明します。
箱の中に、外見上はまったく区別がつかない 3 種のコイン ω1, ω2, ω3 が大量に混ぜ合わされて入っており、その含有率はそれぞれ、π1, π2, π3 とする。また、これら 3 種のコインを投げて表の出る確率はそれぞれ、θ1, θ2, θ3 とする。
(教科書 25 ページ 例題 2.1 より)
それぞれの確率を次のように設定します (教科書 25 ページ 例題 2.1 (2) より)。
コイン (ωi) | 含有率 (πi) | 表の出る確率 (θi) |
---|---|---|
ω1 | 0.1 | 0.8 |
ω2 | 0.4 | 0.6 |
ω3 | 0.5 | 0.3 |
ここで、箱から一枚のコインを取り出して、そのコインを繰り返し投げたところ、最初の 10 回の結果が次のようになったとします (教科書 33 ページ 式 (2.36) より)。結果は H が表 (Head), T が裏 (Tail) を表します。
回数 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
結果 | H | H | H | H | T | H | H | T | H | T |
コインを投げるたびに、それまでに得られた結果を条件として、取り出したコインが ω1, ω2, ω3 のいずれであったかを推定します。数式では教科書 33 ページの式 (2.34) で求まります。これを MATLAB で計算してみます。
p = [0.1 0.4 0.5]; % パラメータ π t = [0.8 0.6 0.3]; % パラメータ θ x = [1 1 1 1 0 1 1 0 1 0]; % 観測結果 (表 = 1, 裏 = 0) cx = [0 cumsum(x)]; % n 回目の観測までで表の出た回数 (n = 0 を含める) np = length(p); % コインの種類の数 (計算用, np = 3) nx = length(x); % 観測回数 (計算用, nx = 10) % 事後確率を計算: M の i 行 j 列は (i-1) 回目までの観測のもとでコインが ωj である確率を表す M = bsxfun(@times, p, binopdf(repmat(cx', 1, np), repmat((0:nx)', 1, np), repmat(t, nx + 1, 1))); M = bsxfun(@rdivide, M, sum(M, 2)); % 結果をプロット plot(0:nx, M); xlim([0 nx]); ylim([0 1]);
実行結果のグラフは次のとおりです。青、赤、黄がそれぞれ ω1, ω2, ω3 の確率を表します。教科書 34 ページ 図 2.3 の n = 10 までの結果と一致しています。
続いて、教科書の図 2.3 と同様に n = 100 までの結果を描いてみます。コインが ω1 であるとして、確率 θ1 = 0.8 で表になるような観測データを生成します。先ほどのプログラムにコードを追加します。
... x = [1 1 1 1 0 1 1 0 1 0]; % 観測結果 (表 = 1, 裏 = 0) % n = 100 までの観測データを x に追加する n = 100; x(end+1:n) = rand(1, n - length(x)) < t(1); ...
実行結果は次のとおりです。今回の実験では、教科書の図 2.3 のようには収束してくれませんでした。コイン ω1, ω2 で表の出る確率が θ1 = 0.8, θ2 = 0.6 と比較的近いので、生成されたデータによっては、このような結果になります。
次のグラフは、n 回目の観測までで表の出た割合を合わせてプロットしたものです。θ1 = 0.8 ですが、生成されたデータでは偶然に表の割合が小さかったため、ω2 の事後確率が高くなっている様子が分かります。
なお、このグラフは subplot コマンドを用いて次のコードで作成しました。
... % 結果をプロット (前半は事後確率のグラフ, 後半は表の出た割合のグラフ) subplot(2, 1, 1); plot(1:nx, M); xlim([1 nx]); ylim([0 1]); subplot(2, 1, 2); plot(1:nx, cx ./ (1:nx)); xlim([1 nx]); ylim([0.6 1]);
観測回数を n = 500 として 100 系列の観測データを生成し、各系列について事後確率をプロットした様子をアニメーションさせてみます。結果は次のようになりました。n = 100 程度では時々大きくずれることがあるようですが、n が大きくなるにつれて、ばらつきが少なくなり ω1 の事後確率が高くなっていきます。
なお、上記のアニメーション GIF は MATLAB の imwrite コマンドで作成しています。公式ドキュメントに記載があります。
イメージをグラフィックス ファイルに書き込む - MATLAB imwrite - MathWorks 日本