kerasを使ったreuter記事分類のexampleをなぞる

github.com

kerasにはデータセットが用意されています1

今回はその中からReuter記事データの分類をしてみます。 基本的にはこちらのexampleに示されているコードをなぞる形です。

kerasのインストール

pythonは3系を使います。

TensorFlowのインストール

keras自身のインストールに先立ち、バックエンドとして使用する機械学習ライブラリをインストールする必要があります。TensorFlowやTheano,CNTKなどから選ぶ事ができますが2、今回はTensorFlowを利用します。

$ sudo apt-get install python3-pip python3-dev
$ sudo pip3 install tensorflow

kerasのインストール

$ sudo pip3 install keras

Reuter記事の分類

データセットのロード

from keras.datasets import reuters
(x_train, y_train), (x_test, y_test) = reuters.load_data(num_words=1000, test_split=0.2)

記事のデータセットエンコードされ、配列になっています。それぞれの単語にはデータセット全体の中で頻度の多い方から順にインデックス番号が振られています。

x_train.shape # (8982,)
x_train[0][:10] # [1, 4, 2, 2, 9, 697, 2, 111, 8, 25]

引数のnum_wordsは最頻出上位○○位の値、test_splitはtrainデータとtestデータの割合(今回は4:1)を指定しています。

データの前処理

入力データ(x)は配列から行列の表現に変換する必要があります。

from keras.preprocessing.text import Tokenizer
tokenizer = Tokenizer(num_words=1000)
x_train = tokenizer.sequences_to_matrix(x_train, mode='binary')
x_test  = tokenizer.sequences_to_matrix(x_test, mode='binary')

tokenizerのイニシャライザには先ほどと同じ値を指定します。sequences_to_matrixmodeにbinaryを指定すると、文中にインデックス番号の単語が含まれていれば1、そうでなければ0が値となるベクトルに変換されます。

教師データ(y)はone-hot表現に変換します。

from keras.utils import to_categorical
y_train = to_categorical(y_train)
y_test  = to_categorical(y_test)

モデルの構築

exampleよりもシンプルなモデルにします。出力のサイズ(分類の数)は46です。

from keras.models import Sequential
from keras.layers import Dense
model = Sequential()
model.add(Dense(512, input_shape=(1000,), activation='relu'))
model.add(Dense(256, activation='relu'))
model.add(Dense(128, activation='relu'))
model.add(Dense(46, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

学習

model.fit(x_train, y_train, batch_size=32, epochs=5, verbose=2, validation_data=(x_test, y_test))

(メモ: GCEのn1-standard-1タイプインスタンスで各エポックあたり6s)

評価

score = model.evaluate(x_test, y_test, verbose=0)
score[0] # 損失関数の値: 1.0874391948125879
score[1] # 正答率: 0.7738201247191493

コード全体 → reuter_mnist.py · GitHub

疑問

  • modeの値にbinary以外を指定するとどうなるんだろう...

Refs