PythonでChainerを使った手書き数字識別

2018/12/17 通信処理ネットワーク研究室 河田直樹

お品書き

今回の流れ

今回は Chainerで手書き数字の学習と実際に手書き文字を識別する体験 を行う.

学習ではMNIST(http://yann.lecun.com/exdb/mnist/)をデータセットとして用いる.

学習後,そのモデルを用いて,各自が書いた数字の識別テストをやってもらう.

1. 環境作り

Chainer (チェイナー) は,ニューラルネットワークの計算および学習を行うためのオープンソースソフトウェアライブラリである.

Pythonではパッケージをインストールするだけで環境作りができる.

グラフ表示のためにmatplotlib,識別テスト用の画像を書くためにPintaを入れてもらう.

ここではその導入方法を紹介する.

ライブラリのインストール

ラズパイのターミナル(端末)を起動し,以下のコマンドを入力する.

$ sudo pip install chainer==1.24.0
$ sudo apt-get install python-matplotlib
$ sudo apt-get install pinta

環境作りおしまい.

2. 学習

インストールしたChainerを利用して学習を行う.

その前に,少しディープラーニングについて説明したいと思う.


ディープラーニング(深層学習)って?

大雑把に言うと,入力(学習データ)に対して複数の条件(重みやバイアス)をかけて特徴を調べ(順伝播),正しい出力と比較し(逆伝播), 正しい出力が得られるように特徴を調節して(update),指定された回数(epoc数)まで繰り返す.

図3


MNISTについて

手書き数字の画像セット.機械学習の分野で最も有名なデータセットの1つ.

データセットは70000枚.(訓練データ: 60000枚,テストデータ: 10000枚)

http://yann.lecun.com/exdb/mnist/


実際に学習させてみる

Chainerの公式サンプル(https://github.com/chainer/chainer/blob/master/examples/mnist/train_mnist.py)をそのまま利用すると授業時間中に学習が終わらない.

そこで今回はソースコードに修正を加えて,0〜9ではなく,0か1かを学習させることにした.

修正したファイルをwgetでダウンロードする

$ wget http://icrus.org/kawada/dsp/chainer/train.py

ダウンロードしたファイルを実行

$ python train.py

実行結果

	GPU: -1
	# unit: 50
	# Minibatch-size: 50
	# epoch: 5

	epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
	1           0.0932722   0.00282592            0.990551       0.999535                  21.2169
	2           0.00367293  0.00128751            0.998893       0.999535                  44.2907
	3           0.00252816  0.00144561            0.999209       0.999535                  67.4477
	4           0.00163583  0.00117042            0.999606       0.999535                  90.8876
	5           0.00111253  0.00115027            0.999526       0.999535                  114.279
	save the trained model: result/MLP.model

学習自体は2分ほどで終わるが,初回はMNISTデータセットのダウンロードに時間がかかる.


プログラムの解説

このサンプルプログラムについて,わかりやすい記事(http://ailaby.com/chainer_mnist/)

 →ただし,今回はサンプルプログラムと違い,2つの数字のみのデータしか扱っていません.引数のデフォルトの数値も変えています.

ニューラルネットワークについて,わかりやすい記事(https://qiita.com/t-tkd3a/items/9bf50f2e10e6a15b6ed5)

3. テスト用画像の準備

Pintaを起動する

スタートメニュー>グラフィック>Pintaを選択して起動する.

キャンバスサイズの変更

入力画像をデータセットと同じ条件にするため変更する

[Shift+Ctrl+R] もしくは 画像(I)>キャンバスサイズの変更 をクリックしてオプション画面を呼び出す.

絶対サイズ にチェック

縦横比を保持する のチェックを外す

幅と高さ を 280ピクセル に指定する

OK をクリック

図1

ブラシサイズを変更

ツールバー(左)よりブラシツールを選択する.

メニューバー(上)付近のブラシサイズより2(デフォルト)から25に変更する.

図2

任意の数字を書く(今回は0か1)

図3

画像を保存する

メニューバーのファイル(F)から名前をつけて保存>

場所はpiのホームディレクトリを指定

保存形式はpngを指定

ファイル名はinputとすること(評価時に読み込むため)

図4

4. テストの実行

piのホームディレクトリの保存したinput.pngを参照して,学習済みモデル(result/MLP.model)より数字がいくつかを当てます.

今回使用するプログラムはInterFace2017年8月号で使用されたものに少し手を加えたものです.

wgetでダウンロード

$ wget http://icrus.org/kawada/dsp/chainer/eva.py

eva.pyはtrain.pyと同じディレクトリに保存してください.

ファイル確認

$ pwd
/home/pi
$ ls
input.png eva.py result train.py (その他色々,この4つあればいい)

テストの実行

$ python eva.py

実行結果

判定結果は1です.

5. おまけ

今回用意した train.py に引数を指定することで,0と1以外の組み合わせを指定できます.

指定方法: -n 1つ目の数字 -s 2つ目の数字

実行例

$ python train.py -n 2 -s 5

実行結果

GPU: -1
# unit: 50
# Minibatch-size: 50
# epoch: 5

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           0.17343     0.0302544             0.962281       0.988675                  18.9174
2           0.0275413   0.0170461             0.990526       0.994359                  39.8739
3           0.0144924   0.0108288             0.995154       0.997436                  60.8911
4           0.0081583   0.0145591             0.997456       0.993291                  83.394
5           0.00507223  0.00592832            0.998502       0.997949                  104.536
save the trained model: result/MLP.model

これでpintaでinput画像に2か5を書き,保存してeva.pyで評価することにより確認できる.

参考文献

  1. Chainer公式サンプルMNISTを読み解いてみた (http://okkah.hateblo.jp/entry/2018/05/01/202346)
  2. Chainer – ドットツールズ (https://atl2.net/chainer/)
  3. Mind で Neural Network (準備編2) 順伝播・逆伝播 図解 (https://qiita.com/t-tkd3a/items/9bf50f2e10e6a15b6ed5)
  4. softmax関数を直感的に理解したい (https://qiita.com/rtok/items/b1affc619d826eea61fd)
  5. Chainerのソースを解析。MNISTサンプルを追ってみる (http://ailaby.com/chainer_mnist/)
dspトップへ