PHP-ML で手書き数字認識 (勉強会発表資料)
PHP-ML という機械学習ライブラリを使って、MNIST の手書き数字認識を試してみました。11 月 29 日に開催された 第120回 PHP勉強会 の LT で発表しましたので、スライドを公開します。
勉強会の LT 発表では、学習したモデルで手書き数字の認識を試せるように、簡単なデモを作成して紹介しました。下記のウェブサイトで、マウスで適当な数字を描いて「実行」ボタンをクリックすると、判別結果が表示されます。
手書き数字認識デモ
PHP-ML は、PHP で実装された機械学習ライブラリです。現時点で、分類、回帰、クラスタリング等の代表的なアルゴリズムがいくつか実装されています。ほぼすべて PHP で実装されているため*1、処理速度、メモリ消費量の両面で大規模データでの実用は難しいという印象ですが、簡単な問題に適用して遊んでみたり、コードを読んで勉強してみたりするには良いかもしれません。
PHP-ML - Machine Learning library for PHP
今回は、ロジスティック回帰を MNIST データセットに適用して、モデル学習、判別を行いました。多クラス分類については one-vs-rest の手法が組み込まれており、利用側での対応は不要です。一方、現時点ではロジスティック回帰の実装にいくつかバグがあるようで、正しく動作させるためにはいくつかの修正が必要でした。修正した状態のコードは私のリポジトリにあります。
[2017-12-05 追記] master ブランチで修正済みです。
勉強会の LT では、残念ながら実装内容まで立ち入って説明することができなかったので、この記事で詳しく説明します。
MNIST データセットを PHP-ML の入力形式に変換する
まず、今回利用する MNIST のデータセットをウェブサイトからダウンロードします。
MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
PHP-ML では、CsvDataset クラスを用いて、特徴量とラベルからなる CSV 形式のデータセットを読み込めます。
CSV Dataset - PHP-ML - Machine Learning library for PHP
そこで、ダウンロードした MNIST のデータセットを展開して、この形式の CSV ファイルに変換します。次のようなスクリプトで変換できます。このあたりの詳細は過去の記事で紹介しています。各ピクセルの値は 0 から 255 の範囲にあるので、255 で割って 0 から 1 の浮動小数点数に変換しておきます。
MNIST 手書き数字データを画像ファイルに変換する - y_uti のブログ
#!/bin/bash for l in train t10k; do paste -d' ' \ <(gzip -dc $l-images-idx3-ubyte.gz | od -An -v -tu1 -j16 -w784) \ <(gzip -dc $l-labels-idx1-ubyte.gz | od -An -v -tu1 -j8 -w1) \ | awk -F, '{ for (i = 1; i < NF; i++) printf("%f,", $i / 255); printf("%d\n", $NF); }' >$l.csv done
変換結果は、各行が 785 列の CSV ファイルになります。先頭から 784 列目までは、28x28 のグレースケール画像における各ピクセルの値を表します。最後の列は、その画像が 0 から 9 のどの数字であるかを表します。
モデルを学習する
モデル学習を行うプログラムは、以下のように実装できます*2。
<?php require_once __DIR__ . '/vendor/autoload.php'; use \Phpml\Classification\Linear\LogisticRegression; use \Phpml\Dataset\CsvDataset; use \Phpml\ModelManager; // 学習用のデータセットを読み込む $dataset = new CsvDataset('train.csv', 784, false); // モデルを学習する $classifier = new LogisticRegression(100, false, LogisticRegression::BATCH_TRAINING); $classifier->train($dataset->getSamples(), $dataset->getTargets()); // 学習によって得られたモデルを保存する $modelManager = new ModelManager(); $modelManager->saveToFile($classifier, 'model.dat');
最初に、CSV ファイルからデータセットを読み込みます。先頭の二つの引数は必須で、ファイル名と特徴量の次元数を指定します。第三引数の false は、CSV ファイルがヘッダ行を含まないことを指定します。
学習処理の本体では、LogisticRegression クラスのインスタンスを生成して、学習を実行します。コンストラクタの引数に、最大の反復回数、データの正規化の有無、学習方法を指定しています。第三引数には ONLINE_TRAINING, CONJUGATE_GRAD_TRAINING も指定できますが、私が今回試した限りでは CONJUGATE_GRAD_TRAINING は上手く動きませんでした。このほかに、第四引数以降でコスト関数や正則化項も指定できます。詳細はソースコードのコメントを参照してください。
php-ml/LogisticRegression.php at 0.5.0 · php-ai/php-ml · GitHub
学習済みのモデルは、ModelManager クラスを利用してファイルに保存できます。モデルを保存しておくことで、モデル学習処理と学習結果を用いた判別処理を分離できます。
このように簡単なプログラムでモデル学習を行えますが、これを実際に動かすには大きな計算リソースが必要です。私が試したところでは、PHP 7.2.0 RC5 で memory_limit=1024M とした環境で train.csv はメモリ不足になりました。
判別する
学習済みのモデルを利用してテストデータを判別する処理は、次のように実装できます。判別処理の本体はプログラムの前半部分で、後半は判別精度を評価する処理です。
<?php require_once __DIR__ . '/vendor/autoload.php'; use \Phpml\Classification\Linear\LogisticRegression; use \Phpml\Dataset\CsvDataset; use \Phpml\Metric\Accuracy; use \Phpml\Metric\ConfusionMatrix; use \Phpml\ModelManager; // 判別対象のデータセットを読み込む $dataset = new CsvDataset('t10k.csv', 784, false); // 学習済みのモデルを読み込む $modelManager = new ModelManager(); $classifier = $modelManager->restoreFromFile('model.dat'); // 判別する $predicted = $classifier->predict($dataset->getSamples()); // 正解率を計算する $accuracy = Accuracy::score($dataset->getTargets(), $predicted); echo 'Accuracy = ', $accuracy * 100, "%\n"; // 混同行列を計算する $confusionMatrix = ConfusionMatrix::compute($dataset->getTargets(), $predicted); foreach ($confusionMatrix as $row) { array_walk($row, function ($col) { printf('%4d ', $col); }); echo "\n"; }
最初に、CSV ファイルからデータセットを読み込みます。これはモデル学習と変わりません。次に、ModelManager クラスを利用して、ファイルに保存していたモデルを読み込みます。これらを利用して、LogisticRegression クラスの predict メソッドで判別処理を実行します。判別結果として得られる $predicted は、各要素が 0 から 9 のいずれかの値を持つ配列です。データセットの各サンプルに対する判別結果を表します。
プログラムの後半では、Accuracy クラスと ConfusionMatrix クラスを用いて、正解率、混同行列を計算して表示します。このような指標によって、学習されたモデルの良し悪しを評価できます。
Accuracy - PHP-ML - Machine Learning library for PHP
Confusion Matrix - PHP-ML - Machine Learning library for PHP
実行例
以上のプログラムを実行した様子を示します。私の環境では 60,000 文字のデータを学習できるだけのリソースが無かったため、CSV の先頭から 10,000 行を取り出して学習データとします。
$ mv train.csv train-all.csv $ head -n 10000 train-all.csv >train.csv
作成したプログラムを用いてモデル学習を行い、続いて t10k.csv を判別します。
$ php train.php $ php predict.php Accuracy = 84.95% 953 0 3 6 1 4 10 2 1 0 0 1105 14 3 0 6 4 0 3 0 5 9 905 40 11 6 20 21 12 3 5 1 25 916 3 42 3 10 2 3 1 1 11 2 926 4 10 6 1 20 15 4 5 41 16 780 14 7 6 4 8 3 14 2 11 25 892 3 0 0 6 9 39 7 5 5 1 945 0 11 14 14 109 94 18 285 19 28 388 5 9 7 10 24 121 55 1 97 0 685
正解率は 84.95% となりました。混同行列は、各行が正解ラベル、各列が判別結果を表します。それぞれ 0 から 9 の順です。たとえば 2 行 3 列の 14 という値は、実際は 1 であるが 2 と判別されたものが 14 サンプル存在したという意味です。対角線上は実際の数字と判別結果が一致しているもので、これらは判別に成功したことになります。対角線上の数値を合計すると 8,495 となり、全体で 10,000 サンプルなので正解率 84.95% になります。混同行列を見ることで、全体の正解率だけではなく "1" は良く判別できているが "8" は極端に判別精度が低いといった情報を読み取れます。
ミニバッチ学習
モデル学習の際に train メソッドに代えて partialTrain メソッドを用いると、これまでの学習結果を初期状態として学習を続けることができます。この機能を利用してミニバッチ学習を試してみます。まず、train.csv を適当に分割します。ここでは split コマンドを用いて 1,000 行ずつ分割しました。train-00 から train-59 までの 60 ファイルが生成されます。
$ split -l 1000 -d train.csv train-
ミニバッチ学習の実装は以下のとおりです。最初に LogisticRegression クラスのインスタンスを生成して、二重のループで学習を行います。内側のループでは、1,000 サンプルずつに分割した各データセットを順に読み込んでモデル学習を行います。これを 10 回繰り返して全体の学習としています。
<?php require_once __DIR__ . '/vendor/autoload.php'; use \Phpml\Classification\Linear\LogisticRegression; use \Phpml\Dataset\CsvDataset; use \Phpml\ModelManager; $classifier = new LogisticRegression(1, false, LogisticRegression::BATCH_TRAINING); for ($n = 0; $n < 10; ++$n) { for ($k = 0; $k < 60; ++$k) { $dataset = new CsvDataset(__DIR__ . '/train-' . sprintf('%02d', $k), 784, false); $classifier->partialTrain($dataset->getSamples(), $dataset->getTargets()); } } $modelManager = new ModelManager(); $modelManager->saveToFile($classifier, 'model-minibatch.dat');
実行例は以下のとおりです*3。ミニバッチ学習では全体として 60,000 サンプルを学習しているため、より高い精度を得ることができました。
$ php train-minibatch.php $ php predict-minibatch.php Accuracy = 91.08% 961 0 0 1 0 2 10 2 4 0 0 1102 2 2 1 2 4 1 21 0 11 11 886 19 12 6 14 15 50 8 7 0 17 909 2 27 5 11 25 7 4 3 5 1 908 2 9 1 13 36 11 4 2 28 11 772 20 7 28 9 14 3 2 2 11 18 902 1 5 0 5 14 20 5 9 3 3 935 4 30 9 10 6 16 14 35 12 9 857 6 10 9 2 13 47 13 1 25 13 876