先月から、Andrew Ng 先生による Coursera の機械学習のコースを受講しています。各週 1 コマの分量になるように構成されていて、動画による説明を聴き、その週の範囲に関する宿題を提出しながら進めていく形式です。現在、第 3 週のロジスティック回帰まで進んだところです。そこで今回は、ロジスティック回帰について MATLAB でグラフを描きながら遊んでみます。なお、第 3 週の後半では正則化の話があるようですが、実はまだ前半しか終えていないので、今回の記事では正則化については考えません。
www.coursera.org
ロジスティック回帰は、判別問題に適用される手法です。 と の関係を次の式で予測します*1。 は、ある で となる確率を予測する関数で、 のときに と、そうでないときに と判別します。
による予測がどのくらい当たるかは、パラメータ の値に依存します。 組のデータ が与えられたとき、 による予測の「悪さ」をコスト関数 で評価できます。そこで、共役勾配法などの手法を用いて、 を最小にするパラメータ を求めます。
まず、 がどのような形をしているのか、グラフを描いて様子を見てみます。 を計算する関数を次のように実装しました。
function value = predict(theta, X) value = sigmoid(X * theta); end function y = sigmoid(x) y = 1 ./ (1 + exp(-x)); end
この関数を使って、次のプログラムでグラフを描画します。 に固定して、 の 4 通りのグラフを描いてみました。
x = (-10:0.1:10)'; figure; title('\theta_1 を変えたときの h_\theta(x) のグラフ'); xlabel('x'); ylabel('h_\theta(x)'); legends = {}; hold on; for t1 = -2:2:4 plot(x, predict([t1; 1], [ones(length(x), 1) x])); legends{end+1} = sprintf('\\theta = (%d, %d)', t1, 1); end legend(legends);
実行結果は次のとおりです。 を固定して を変えると、グラフは傾きを変えずに左右に動きます。実際 の式を考えると、 が識別境界 になることがわかります。
同様に、 を変えたときのグラフを以下に示します。左は の場合、右は の場合です。 の値によってグラフの傾きが変わる様子がわかります。 でグラフは水平になり、 では右下がりのグラフになります。
それでは次に、コスト関数 を実装して、適当なデータに対する最適な を求めてみます。コスト関数は次のように実装できます。
function value = cost(theta, X, y) h = predict(theta, X); value = -mean(y .* log(h) + (1 - y) .* log(1 - h)); end
次の 4 点からなる簡単なサンプルデータを考えます。
x | -1.0 | -0.1 | 0.1 | 1.0 |
---|---|---|---|---|
y | 0 | 1 | 0 | 1 |
次のコードは、このサンプルデータを作成して、いくつかの についてコストを計算してみたものです。試した 3 通りの中では が最も良いようです。
x = [-1; -0.1; 0.1; 1]; y = [ 0; 1 ; 0 ; 1]; cost([0; -1], [ones(length(x), 1) x], y) % 0.9788 cost([0; 1], [ones(length(x), 1) x], y) % 0.5288 cost([1; 1], [ones(length(x), 1) x], y) % 0.6371
制約無し最小化問題を解く fminunc 関数を利用して、最適な を求めます。先ほど実装した cost 関数は 3 引数の関数でしたが、fminunc には を受け取る関数を渡す必要があるので、無名関数でラップして渡しています*2。fminunc の第 2 引数は反復計算の初期値です。
theta = fminunc(@(t) cost(t, [ones(length(x), 1) x], y), [0; 0]);
計算の結果、 と求まりました。得られた でコストを計算すると、0.4510 になりました。次のコードで、この での を描いてみます。
xs = (-1.5:0.01:1.5)'; figure; hold on; plot(xs, predict(theta, [ones(length(xs), 1) xs])); plot(x, y, 'o'); xlabel('x'); ylabel('h_\theta(x)');
結果は次のとおりです。 がこれよりも大きくなると、両端の 2 点への当てはまりは改善されますが、中央の 2 点への当てはまりが悪くなります。逆に が小さくなると、中央の 2 点への当てはまりは改善されますが、両端の 2 点への当てはまりが悪くなります。これらのバランスを取った最適解がグラフの曲線になっています。
さて、ロジスティック回帰に用いるコスト関数 は、パラメータ に関してどのような形になっているでしょうか。次のコードで、 を動かしたときの を描画してみます。
[T1, T2] = meshgrid(-5:0.025:5, -2:0.025:8); J = arrayfun(@(i) cost([T1(i); T2(i)], [ones(length(x), 1) x], y), 1:length(T1(:))); J = reshape(J, size(T1)); s = surf(T1, T2, J); s.LineStyle = 'none'; view(90, -90); colorbar(); xlabel('\theta_1'); ylabel('\theta_2'); zlabel('J(\theta)');
結果は次のとおりです。下に凸な曲面になっており、先ほど得られた で最小になります。下に凸になっているので、適当な初期値で fminunc を実行しても必ず最適解が得られます。もしコスト関数がデコボコな形になっていると、fminunc は局所解の一つを求めることになり、それは必ずしも全体の最適解とは限りません。
Coursera の講義の中で、線形回帰で用いるコスト関数をロジスティック回帰に使うと、下に凸にならないという説明がありました。式に書くと以下のようになります。
こちらのコスト関数についても、実際どのような形になってしまうのかグラフに描いて確認してみます。このコスト関数は次のように実装できます。
function cost = costlin(theta, X, y) h = predict(theta, X); cost = mean((h - y) .^ 2) / 2; end
先ほどのプログラムで cost を呼び出していたところを costlin に変更すれば、同様にグラフを描画できます。結果は次のようになりました。歪んだ形にはなっていますが、下に凸になっているようにも見えます。これは、サンプルデータが単純すぎたことが理由かもしれません。
そこで、もう少し複雑なサンプルデータを使って確認してみます。次のように、平均の異なる正規分布から のデータを適当に発生させます。
x = normrnd([-3 * ones(10, 1); -1 * ones(5, 1); ones(10, 1); 2 * ones(5, 1)], 1); y = [zeros(10, 1); ones(5, 1); zeros(10, 1); ones(5, 1)];
生成したデータに対して最適なパラメータは となりました。これをプロットしたものが次の図です。
先ほどと同様にコスト関数を描いてみたものが次の図です。左が を用いたもの、右が を用いたものです。 は、このような複雑なデータに対しても比較的綺麗な形をしており、下に凸な曲面になっています。一方で は大きく歪んだ形をしており、 のあたりに鞍部があることがわかります*3。
コスト関数が右の図のようになっていると、fminunc に与える初期値によって異なる解が求まります。このことを確認してみます。fminunc が costlin 関数を使うようにして を計算します。初期値を (0, 0) として計算すると、(-0.6773, 0.3383) という解が得られました。下図左がこのパラメータで描画したグラフです。一方、初期値を (-3, 2) として計算すると、(-2179.7, 1279.5) という解が得られました。これが鞍部の右下側に落ちていった場合で、グラフを描画すると下図右になります。1 つのデータ群を完全に無視することで残りのデータを完全に説明できた形になっており、これはこれで面白い結果ですね。
さて、ここで得られた 2 つの解について でコストを計算してみると、左は 0.1006、右は 0.0833 となり、「まともな」結果に見える左の方が実は局所解で、右側の方がコストの小さな解になっていることがわかります。これはコスト関数の性質によるもので、 では、あるデータに対する予測を完全に間違えても ( となっても)、それによるコストの増加は高々 にしかならないので、このような結果になります。 では、そのような完全な間違いに対して支払うコストは無限大になるので、右図のようなパラメータに対する は大きな値になります*4。