概要

Chainerで簡単な分類予測を実装してみる。
使用するデータは iris。
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html

目次

実装

import numpy as np
import matplotlib.pyplot as plt
import chainer
import chainer.links as L
import chainer.functions as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from chainer import Sequential


# irisデータを読み込んでデータ型を変換
x, t = load_iris(return_X_y=True)
x = x.astype('float32')
t = t.astype('int32')

# 訓練データと検証データに分割
x_train, x_test, t_train, t_test = train_test_split(x, t, test_size=0.3, random_state=0)

# 正規化用(標準化用)
scaler = StandardScaler()
scaler.fit(x)

# --------------------------------------------
# ネットワークを定義
# --------------------------------------------

# 学習率、学習回数
alpha = 0.01
iter_num = 1000

n_input = 4
n_hidden = 10
n_output = 3

class MyModel(chainer.Chain):
    """ネットワーク."""

    def __init__(self):
        """コンストラクタ."""
        super().__init__(
            l1=L.Linear(n_input, n_hidden),
            l2=L.Linear(n_hidden, n_hidden),
            l3=L.Linear(n_hidden, n_output)
        )

    def forward(self, x, t=None):  # Chainer v5では、__ call__ の代わりに forwardを使用することが推奨されている模様
        """順伝番の実施."""
        x = x.astype("float32")
        t = t.astype("int32") if t is not None else None

        # 順伝番
        h = self.acc(self.l1(x))
        h = self.acc(self.l2(h))
        y = self.l3(h)

        # コスト 及び 精度の計算
        if t is not None:
            loss = F.softmax_cross_entropy(y, t)  # 目的関数は交差エントロピーを採用
            acc = F.accuracy(y, t)                              # 分類精度を計算
            return y, loss, acc
        else:
            return y

    def acc(self, h):
        """活性化関数."""
        return F.relu(h)  # ReLU 関数
        #return F.sigmoid(h)  # シグモイド関数

net = MyModel()

# 上記と同義、目的関数などはモデル定義に含まない
"""
net = Sequential(
    L.Linear(n_input, n_hidden), F.relu,
    L.Linear(n_hidden, n_hidden), F.relu,
    L.Linear(n_hidden, n_output)
)
"""

# 最適化手法として確率的勾配降下法(SGD)を採用
optimizer = chainer.optimizers.SGD(lr=alpha)
optimizer.setup(net)

# --------------------------------------------
# ネットワークの訓練
# --------------------------------------------

# 目的関数の出力と分類精度の保存用
loss_history = []
accuracy_history = []

for epoch in range(iter_num):

    # 正規化(標準化)
    x_train_scaled = scaler.transform(x_train)

    """
    # モデルの訓練
    #y_train = net(x_train_scaled)

    # 目的関数の適用 及び 分類精度の計算
    loss_train = F.softmax_cross_entropy(y_train, t_train)  # 目的関数は交差エントロピーを採用
    accuracy_train = F.accuracy(y_train, t_train)                    # 分類精度を計算
    """

    # モデルの訓練
    y_train, loss_train, accuracy_train = net(x_train_scaled, t_train)
    
    # 履歴に追加
    loss_history.append(loss_train.array)
    accuracy_history.append(accuracy_train.array)

    # 100件ごとに精度を表示
    if epoch == 0 or (epoch + 1) % 100 == 0:
        print('epoch: {}, loss (train): {:.4f}, accuracy(train): {:.4f}'.format(
            epoch+1, loss_train.array, accuracy_train.array))

    # 勾配のリセットと勾配の計算
    net.cleargrads()
    loss_train.backward()

    #  パラメータの更新
    optimizer.update()

# 訓練済みネットワークを保存
chainer.serializers.save_npz('chainer_classification1', net)

# --------------------------------------------
# 検証データで予測してみる
# --------------------------------------------

with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
    x_test_scaled  = scaler.transform(x_test)

    # 検証データで推論
    #y_test = net(x_test_scaled)

    y_test, loss_test, accuracy_test = net(x_test_scaled, t_test)
    y_test_val = [np.argmax(y_test.array[i]) for i in range(y_test.shape[0])]

    # 目的関数を適用し、分類精度を計算
    #loss_test = F.softmax_cross_entropy(y_test, t_test)
    #accuracy_test = F.accuracy(y_test, t_test)

    print("検証データでの分類精度: {:.4f}".format(accuracy_test.data))
    for i in range(y_test.shape[0]):
        print("t_test[{}]: {}, y_test[{}]: {}, same: {}".format(i, t_test[i], i, y_test_val[i], t_test[i] == y_test_val[i]))

    # 検証データの混同行列を確認
    matrix_test = confusion_matrix(y_test_val, t_test)
    print(matrix_test)

    # 検証データの正解率
    #print(accuracy_score(y_test_val, t_test))


# --------------------------------------------
# 目的関数 及び 分類精度のグラフ出力
# --------------------------------------------

fig = plt.figure(figsize=(10, 5))
                 
# 目的関数(loss)
x1 = fig.add_subplot(1, 2, 1)
x1.plot(loss_history)
x1.set_xlabel("iterations")
x1.set_ylabel("loass")
x1.grid(True)

# 分類精度 (accuracy)
x2 = fig.add_subplot(1, 2, 2)
x2.plot([val * 100 for val in accuracy_history])
x2.set_xlabel("iterations")
x2.set_ylabel("accuracy")
x2.grid(True)

plt.show()

結果

epoch: 1, loss (train): 1.3934, accuracy(train): 0.1619
epoch: 100, loss (train): 0.8345, accuracy(train): 0.5714
epoch: 200, loss (train): 0.6947, accuracy(train): 0.7905
epoch: 300, loss (train): 0.5798, accuracy(train): 0.8857
epoch: 400, loss (train): 0.4589, accuracy(train): 0.8952
epoch: 500, loss (train): 0.3476, accuracy(train): 0.9048
epoch: 600, loss (train): 0.2640, accuracy(train): 0.9048
epoch: 700, loss (train): 0.2086, accuracy(train): 0.9333
epoch: 800, loss (train): 0.1725, accuracy(train): 0.9619
epoch: 900, loss (train): 0.1478, accuracy(train): 0.9619
epoch: 1000, loss (train): 0.1299, accuracy(train): 0.9619

検証データでの分類精度: 0.9556
t_test[0]: 2, y_test[0]: 2, same: True
t_test[1]: 1, y_test[1]: 1, same: True
t_test[2]: 0, y_test[2]: 0, same: True
t_test[3]: 2, y_test[3]: 2, same: True
t_test[4]: 0, y_test[4]: 0, same: True
t_test[5]: 2, y_test[5]: 2, same: True
t_test[6]: 0, y_test[6]: 0, same: True
t_test[7]: 1, y_test[7]: 1, same: True
t_test[8]: 1, y_test[8]: 1, same: True
t_test[9]: 1, y_test[9]: 1, same: True
t_test[10]: 2, y_test[10]: 1, same: False
t_test[11]: 1, y_test[11]: 1, same: True
t_test[12]: 1, y_test[12]: 1, same: True
t_test[13]: 1, y_test[13]: 1, same: True
t_test[14]: 1, y_test[14]: 1, same: True
t_test[15]: 0, y_test[15]: 0, same: True
t_test[16]: 1, y_test[16]: 1, same: True
t_test[17]: 1, y_test[17]: 1, same: True
t_test[18]: 0, y_test[18]: 0, same: True
t_test[19]: 0, y_test[19]: 0, same: True
t_test[20]: 2, y_test[20]: 2, same: True
t_test[21]: 1, y_test[21]: 1, same: True
t_test[22]: 0, y_test[22]: 0, same: True
t_test[23]: 0, y_test[23]: 0, same: True
t_test[24]: 2, y_test[24]: 2, same: True
t_test[25]: 0, y_test[25]: 0, same: True
t_test[26]: 0, y_test[26]: 0, same: True
t_test[27]: 1, y_test[27]: 1, same: True
t_test[28]: 1, y_test[28]: 1, same: True
t_test[29]: 0, y_test[29]: 0, same: True
t_test[30]: 2, y_test[30]: 2, same: True
t_test[31]: 1, y_test[31]: 1, same: True
t_test[32]: 0, y_test[32]: 0, same: True
t_test[33]: 2, y_test[33]: 2, same: True
t_test[34]: 2, y_test[34]: 2, same: True
t_test[35]: 1, y_test[35]: 1, same: True
t_test[36]: 0, y_test[36]: 0, same: True
t_test[37]: 1, y_test[37]: 2, same: False
t_test[38]: 1, y_test[38]: 1, same: True
t_test[39]: 1, y_test[39]: 1, same: True
t_test[40]: 2, y_test[40]: 2, same: True
t_test[41]: 0, y_test[41]: 0, same: True
t_test[42]: 2, y_test[42]: 2, same: True
t_test[43]: 0, y_test[43]: 0, same: True
t_test[44]: 0, y_test[44]: 0, same: True
[[16  0  0]
 [ 0 17  1]
 [ 0  1 10]]

chainer_classification.png


添付ファイル: filechainer_classification.png 292件 [詳細]

トップ   差分 バックアップ リロード   一覧 単語検索 最終更新   ヘルプ   最終更新のRSS
Last-modified: 2019-11-28 (木) 21:18:44 (1833d)