y_uti のブログ

統計、機械学習、自然言語処理などに興味を持つエンジニアの技術ブログです

続パタ 2.3 節「ベイズ更新の実験」を試す

『続・わかりやすいパターン認識』の 2.3 節「ベイズ更新の実験」を MATLAB で試してみました。読書会に参加しながら読み進めているのですが、前半部分は残念ながら参加できなかったので、復習の意味で読み直しています。

教科書 2.3 節の実験はコイン投げを題材としたものです。まず、教科書を引用しながら実験内容を説明します。

箱の中に、外見上はまったく区別がつかない 3 種のコイン ω1, ω2, ω3 が大量に混ぜ合わされて入っており、その含有率はそれぞれ、π1, π2, π3 とする。また、これら 3 種のコインを投げて表の出る確率はそれぞれ、θ1, θ2, θ3 とする。
(教科書 25 ページ 例題 2.1 より)

それぞれの確率を次のように設定します (教科書 25 ページ 例題 2.1 (2) より)。

コイン (ωi) 含有率 (πi) 表の出る確率 (θi)
ω1 0.1 0.8
ω2 0.4 0.6
ω3 0.5 0.3

ここで、箱から一枚のコインを取り出して、そのコインを繰り返し投げたところ、最初の 10 回の結果が次のようになったとします (教科書 33 ページ 式 (2.36) より)。結果は H が表 (Head), T が裏 (Tail) を表します。

回数 1 2 3 4 5 6 7 8 9 10
結果 H H H H T H H T H T

コインを投げるたびに、それまでに得られた結果を条件として、取り出したコインが ω1, ω2, ω3 のいずれであったかを推定します。数式では教科書 33 ページの式 (2.34) で求まります。これを MATLAB で計算してみます。

p = [0.1 0.4 0.5];         % パラメータ π
t = [0.8 0.6 0.3];         % パラメータ θ
x = [1 1 1 1 0 1 1 0 1 0]; % 観測結果 (表 = 1, 裏 = 0)

cx = [0 cumsum(x)];        % n 回目の観測までで表の出た回数 (n = 0 を含める)
np = length(p);            % コインの種類の数 (計算用, np = 3)
nx = length(x);            % 観測回数 (計算用, nx = 10)

% 事後確率を計算: M の i 行 j 列は (i-1) 回目までの観測のもとでコインが ωj である確率を表す
M = bsxfun(@times, p, binopdf(repmat(cx', 1, np), repmat((0:nx)', 1, np), repmat(t, nx + 1, 1)));
M = bsxfun(@rdivide, M, sum(M, 2));

% 結果をプロット
plot(0:nx, M);
xlim([0 nx]);
ylim([0 1]);

実行結果のグラフは次のとおりです。青、赤、黄がそれぞれ ω1, ω2, ω3 の確率を表します。教科書 34 ページ 図 2.3 の n = 10 までの結果と一致しています。
f:id:y_uti:20151018000346p:plain

続いて、教科書の図 2.3 と同様に n = 100 までの結果を描いてみます。コインが ω1 であるとして、確率 θ1 = 0.8 で表になるような観測データを生成します。先ほどのプログラムにコードを追加します。

  ...
x = [1 1 1 1 0 1 1 0 1 0]; % 観測結果 (表 = 1, 裏 = 0)

% n = 100 までの観測データを x に追加する
n = 100;
x(end+1:n) = rand(1, n - length(x)) < t(1);
  ...

実行結果は次のとおりです。今回の実験では、教科書の図 2.3 のようには収束してくれませんでした。コイン ω1, ω2 で表の出る確率が θ1 = 0.8, θ2 = 0.6 と比較的近いので、生成されたデータによっては、このような結果になります。
f:id:y_uti:20151018001733p:plain

次のグラフは、n 回目の観測までで表の出た割合を合わせてプロットしたものです。θ1 = 0.8 ですが、生成されたデータでは偶然に表の割合が小さかったため、ω2 の事後確率が高くなっている様子が分かります。
f:id:y_uti:20151018002112p:plain

なお、このグラフは subplot コマンドを用いて次のコードで作成しました。

  ...
% 結果をプロット (前半は事後確率のグラフ, 後半は表の出た割合のグラフ)
subplot(2, 1, 1);
plot(1:nx, M);
xlim([1 nx]);
ylim([0 1]);
subplot(2, 1, 2);
plot(1:nx, cx ./ (1:nx));
xlim([1 nx]);
ylim([0.6 1]);

観測回数を n = 500 として 100 系列の観測データを生成し、各系列について事後確率をプロットした様子をアニメーションさせてみます。結果は次のようになりました。n = 100 程度では時々大きくずれることがあるようですが、n が大きくなるにつれて、ばらつきが少なくなり ω1 の事後確率が高くなっていきます。
f:id:y_uti:20151018003609g:plain

なお、上記のアニメーション GIFMATLAB の imwrite コマンドで作成しています。公式ドキュメントに記載があります。
イメージをグラフィックス ファイルに書き込む - MATLAB imwrite - MathWorks 日本