『言語処理 100 本ノック』に PHP で挑む (問題 78 ~ 79)
『言語処理 100 本ノック』に PHP で挑戦しています。今回は、第 8 章の残りの問題を解いていきます。
www.cl.ecei.tohoku.ac.jp
78. 5分割交差検定
76-77の実験では,学習に用いた事例を評価にも用いたため,正当な評価とは言えない.すなわち,分類器が訓練事例を丸暗記する際の性能を評価しており,モデルの汎化性能を測定していない.そこで,5分割交差検定により,極性分類の正解率,適合率,再現率,F1スコアを求めよ.
まず、交差検定用にファイルを分割するプログラムを実装します。split_feature 関数は入力ファイルを行単位に $n 等分して配列に格納します。PHP には array_chunk 関数が用意されていますが、array_chunk 関数では末尾の要素だけが短くなってしまうので*1、均等に分割できるように自作しています。分割したデータは、train[0-4].txt, test[0-4].txt として一度ファイルに書き出します。
<?php function split_feature($filename, $n) { $lines = file($filename); $splitted = []; $limit = 0; for ($i = 0; $i < $n; ++$i) { $offset = $limit; $limit = (int) (count($lines) * ($i + 1) / $n); $splitted[] = array_slice($lines, $offset, $limit - $offset); } return $splitted; } function write_splitted_features($splitted) { $n = count($splitted); for ($i = 0; $i < $n; ++$i) { unlink("train$i.txt"); for ($j = 0; $j < $n; ++$j) { if ($j != $i) { file_put_contents("train$i.txt", $splitted[$j], FILE_APPEND); } } file_put_contents("test$i.txt", $splitted[$i]); } }
交差検定を行うプログラムは以下のように実装しました。ロジスティック回帰で train$i.txt を学習し、test$i.txt の極性を予測します。処理の内容は、第 8 章のこれまでの問題で実装してきたものと同様です。
<?php function train_and_predict($logisticRegression, $i) { $train = read_feature_data("train$i.txt"); train($logisticRegression, $train, "model$i.txt"); $test = read_feature_data("test$i.txt"); predict($logisticRegression, $test, "predict$i.txt"); } function read_feature_data($filename) { $lines = file($filename, FILE_IGNORE_NEW_LINES); $data = []; foreach ($lines as $line) { $data[] = decode($line); } return $data; } function decode($encoded) { $columns = explode(' ', $encoded); $label = $columns[0] == '+1' ? 1 : 0; $feature = []; for ($i = 1; $i < count($columns); $i += 2) { $feature[$columns[$i]] = (int) $columns[$i + 1]; } return [$label, $feature]; } function train($logisticRegression, $features, $filename) { $logisticRegression->train($features); $fh = fopen($filename, 'wb'); foreach ($logisticRegression->weights as $word => $value) { fwrite($fh, "$word $value\n"); } fclose($fh); } function predict($logisticRegression, $features, $filename) { $counts = [[0, 0], [0, 0]]; $fh = fopen($filename, 'wb'); foreach ($features as list($label, $feature)) { $hypothesis = $logisticRegression->hypothesis($feature); $predicted = $hypothesis > 0.5 ? 1 : 0; $line = ($label ? '+1' : '-1') . "\t" . ($predicted ? '+1' : '-1') . "\t" . $hypothesis; fwrite($fh, "$line\n"); ++$counts[$label][$predicted]; } fclose($fh); $tp = $counts[1][1]; $fp = $counts[0][1]; $fn = $counts[1][0]; $tn = $counts[0][0]; $accuracy = ($tp + $tn) / ($tp + $fp + $fn + $tn); $precision = $tp / ($tp + $fp); $recall = $tp / ($tp + $fn); $f1score = 2 * $precision * $recall / ($precision + $recall); echo "Accuracy = $accuracy\n"; echo "Precision = $precision\n"; echo "Recall = $recall\n"; echo "F1 score = $f1score\n"; }
以上の関数を利用して、次のプログラムで 5 分割交差検定を実行します。LogisticRegression.php は問題 73 で実装したものをそのまま利用します。
<?php require_once __DIR__ . '/LogisticRegression.php'; require_once __DIR__ . '/split_feature.php'; require_once __DIR__ . '/train_and_predict.php'; main(); function main() { $n = 5; $splitted = split_feature('feature.txt', $n); write_splitted_features($splitted); $logisticRegression = new LogisticRegression(1e-4, 1e-2); for ($i = 0; $i < 5; ++$i) { train_and_predict($logisticRegression, $i); } }
実行結果は以下のとおりです。5 分割交差検定のそれぞれに対して、正解率、適合率、再現率、F1 score を出力しています。
$ php main.php Accuracy = 0.74390243902439 Precision = 0.73957367933272 Recall = 0.75070555032926 F1 score = 0.74509803921569 Accuracy = 0.72514071294559 Precision = 0.70183486238532 Recall = 0.74561403508772 F1 score = 0.72306238185255 Accuracy = 0.74074074074074 Precision = 0.7447963800905 Recall = 0.75228519195612 F1 score = 0.74852205547976 Accuracy = 0.72983114446529 Precision = 0.72205438066465 Recall = 0.70501474926254 F1 score = 0.7134328358209 Accuracy = 0.74777308954524 Precision = 0.77127172918573 Recall = 0.74535809018568 F1 score = 0.75809352517986
前回の記事に掲載したように、問題 76 では学習したデータそのものを予測して 99% 近い正解率が得られていましたが、今回の結果では 74% 程度の正解率になっています。これは過学習が起きていることが原因だと考えられるので、ロジスティック回帰に正則化項を導入して比較してみます。
問題 73 で実装した LogisticRegression クラスを継承して、L2 正則化を含む形のコスト関数と勾配を実装しました。
<?php require_once __DIR__ . '/LogisticRegression.php'; class LogisticRegressionL2 extends LogisticRegression { public $penalty; public function __construct($rate = 1.0, $threshold = 1e-3, $penalty = 1.0) { parent::__construct($rate, $threshold); $this->penalty = $penalty; } public function cost() { $cost = parent::cost(); $sum = 0; foreach ($this->weights as $word => $weight) { if ($word != '') { $sum += $weight * $weight; } } $cost += $this->penalty * $sum / 2; return $cost; } public function gradient() { $grad = parent::gradient(); foreach ($this->weights as $word => $weight) { if ($word != '') { $grad[$word] += $this->penalty * $weight; } } return $grad; } }
メインプログラムを以下のように書き換えて実行します。
$ diff main.php main_l2.php 3c3 < require_once __DIR__ . '/LogisticRegression.php'; --- > require_once __DIR__ . '/LogisticRegressionL2.php'; 15c15 < $logisticRegression = new LogisticRegression(1e-4, 1e-2); --- > $logisticRegression = new LogisticRegressionL2(1e-4, 1e-2, 1.0);
実行結果は以下のとおりです。正解率は少し高くなっているものの、それほど大きな差は見られませんでした。
$ cat result_l2_penalty_1/result.txt Accuracy = 0.76594746716698 Precision = 0.76014760147601 Recall = 0.77516462841016 F1 score = 0.7675826734979 Accuracy = 0.74108818011257 Precision = 0.71545454545455 Recall = 0.76705653021442 F1 score = 0.74035747883349 Accuracy = 0.75058602906704 Precision = 0.76018518518519 Recall = 0.75045703839122 F1 score = 0.75528978840846 Accuracy = 0.73545966228893 Precision = 0.72359328726555 Recall = 0.72074729596853 F1 score = 0.72216748768473 Accuracy = 0.74964838255977 Precision = 0.77923292797007 Recall = 0.73651635720601 F1 score = 0.75727272727273
79. 適合率-再現率グラフの描画
ロジスティック回帰モデルの分類の閾値を変化させることで,適合率-再現率グラフを描画せよ.
計算結果を CSV ファイルとして出力し、Excel でグラフを作成する方針で進めます。以下のようにプログラムを実装しました。判別結果を確率の降順に整列して、適合率と再現率、F1 スコアを計算していきます。
<?php main(); function main() { $data = read_predicted_data('predict.txt'); rsort($data); $npos = array_sum(array_column($data, 1)); $tp = 0; foreach ($data as $i => list($probability, $truth)) { $tp += $truth; $precision = $tp / $i; $recall = $tp / $npos; $f1score = 2 * $precision * $recall / ($precision + $recall); echo implode(',', [$probability, $precision, $recall, $f1score]), "\n"; } } function read_predicted_data($filename) { $data = []; $lines = file($filename, FILE_IGNORE_NEW_LINES); foreach ($lines as $line) { list ($truth, $predicted, $probability) = explode("\t", $line); $data[] = [(float) $probability, $truth == '+1' ? 1 : 0]; } return $data; }
問題 78 の交差検定で得られた結果の一つから描画したグラフを以下に示します。左のグラフは閾値を横軸に取って各指標の値をプロットしたもの、右のグラフが precision-recall 曲線を描いたものです。
*1:各 chunk のサイズを指定して分割するので、たとえば 10 要素の配列を 3 要素ずつ分割すると (3, 3, 3, 1) と分割されます。