概要

numpyを使用した重回帰分析 や scikit-learnを使用した重回帰分析 で実装した賃貸価格の推論と同じ事を、Chainer を使用してやってみる。
※ 1層、2入力、1出力のシンプルなネットワークとする。
※ 使用するデータは numpyを使用した重回帰分析 と同じ。

目次

実装

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import chainer
import chainer.links as L
import chainer.functions as F
from chainer import Chain, Variable
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


class MyModel(Chain):
    """Single-layer linear regression network with 2 inputs and 1 output."""

    def __init__(self):
        super().__init__()
        # Register the link inside init_scope(). Passing links as keyword
        # arguments to Chain.__init__ is the deprecated Chainer v1 style.
        # The attribute name stays "l1".
        with self.init_scope():
            self.l1 = L.Linear(2, 1)

    def __call__(self, x):
        """Return the linear prediction for a batch ``x``.

        Assumes ``x`` has shape (batch, 2) with the same dtype as the link
        parameters (float32 in this script) -- TODO confirm at call sites.
        The result has shape (batch, 1).
        """
        return self.l1(x)


# Learning rate and number of (full-batch) training iterations.
alpha = 0.01
iter_num = 1000

# Load the data (the header row is skipped).
data = np.loadtxt("data/sample_rent1.csv", delimiter=",", skiprows=1)
x = data[:, 1:3].astype("float32")  # floor area (m^2), building age (years)
t = data[:, 3:4].astype("float32")  # rent (in units of 10,000 yen)

# Split into training and test data.
x_train, x_test, t_train, t_test = train_test_split(x, t, test_size=0.3, random_state=0)

# Standardization. Fit the scaler on the training split only: fitting on
# the full data set would leak the test split's statistics into training.
scaler = StandardScaler()
scaler.fit(x_train)

# Instantiate the model and the optimizer.
net = MyModel()
optimizer = chainer.optimizers.SGD(lr=alpha)  # stochastic gradient descent (SGD)
optimizer.setup(net)

# Per-iteration training loss log.
loss_history = []

# One arbitrary record, used as a sample to visualize how the prediction
# evolves while training.
x_sample = np.array([[60, 10]], dtype=np.float32)
sample_history = []
print("## ¥µ¥ó¥×¥ë ...  ¹­¤µ: {}­Ö, ÃÛǯ¿ô: {}ǯ ##".format(x_sample[0,0], x_sample[0,1]))

# Standardize the inputs once. The scaler is already fitted and does not
# change, so transforming inside the loop produced identical arrays every
# iteration. Keep float32 (the dtype the data was cast to) in case the
# scaler returns a wider float type.
x_train_scaled = scaler.transform(x_train).astype(np.float32, copy=False)
x_test_scaled = scaler.transform(x_test).astype(np.float32, copy=False)  # for test-split evaluation
x_sample_scaled = scaler.transform(x_sample).astype(np.float32, copy=False)

for epoch in range(iter_num):

    # Forward pass over the whole (standardized) training set.
    y_train = net(x_train_scaled)

    # Training loss: mean squared error.
    loss = F.mean_squared_error(y_train, t_train)
    loss_value = loss.item()
    loss_history.append(loss_value)

    # Reset the accumulated gradients, then backpropagate.
    net.cleargrads()
    loss.backward()

    # Update the parameters.
    optimizer.update()

    # Predict the sample record with the updated parameters.
    with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
        # 'train' False: run the network in inference (test) mode.
        # 'enable_backprop' False: do not build a computational graph.
        y_sample = net(x_sample_scaled)
        sample_value = y_sample.item()
        sample_history.append(sample_value)

    # Every 100 iterations, report the loss and the sample prediction.
    if (epoch + 1) % 100 == 0:
        print("## {}: loss: {}, sample_result: {} ##".format(epoch + 1, loss_value, sample_value))

print("## ·±ÎýºÑ¤ß¥Í¥Ã¥È¥ï¡¼¥¯¤ò»ÈÍѤ·¤ÆǤ°Õ¤ÎÃͤò¿äÏÀ¤·¤Æ¤ß¤ë ##")
with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
    #x_tmp = x_test[:5]
    x_tmp = np.array([[60.0, 10.0],[50.0, 10.0],[40.0, 10.0]]).astype("float32")
    y_tmp = net(scaler.transform(x_tmp))
    [print("¹­¤µ: {:.1f}­Ö, ÃÛǯ¿ô: {}ǯ => ²ÈÄÂ: {:.1f}Ëü±ß".format(x_tmp[i,0], x_tmp[i,1], y_tmp.array[i,0])) for i in range(x_tmp.shape[0])]

# ·±ÎýºÑ¤ß¥Í¥Ã¥È¥ï¡¼¥¯¤ÎÊݸ
chainer.serializers.save_npz('sample1.net', net)

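# Using the trained network: reload it and predict. Inputs must be
# standardized with the same fitted scaler that was used during training.
#loaded_net = MyModel()
#chainer.serializers.load_npz('sample1.net', loaded_net)
#with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
#    y_test = loaded_net(scaler.transform(x_test))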

# Japanese-capable fonts, listed in order of preference. The Windows font
# is spelled "Meiryo"; the misspelling "Meirio" never matched.
plt.rcParams['font.sans-serif'] = ['Hiragino Maru Gothic Pro', 'Yu Gothic', 'Meiryo', 'Takao', 'IPAexGothic', 'IPAPGothic', 'Noto Sans CJK JP']

fig = plt.figure(figsize=(10, 5))

# Left: training loss per iteration.
x1 = fig.add_subplot(1, 2, 1)
x1.set_title("·±Îý²ó¿ôËè¤ÎÀºÅÙ")
x1.plot(loss_history)
x1.set_xlabel("³Ø½¬²ó¿ô")
x1.set_ylabel("ÀºÅÙ(Loss)")
x1.grid(True)

# Right: predicted rent for the sample record per iteration.
x2 = fig.add_subplot(1, 2, 2)
x2.set_title("¹­¤µ: 60  ÃÛǯ¿ô: 10ǯ¤Î¾ì¹ç¤Î²ÈÄÂ")
x2.set_xlabel("³Ø½¬²ó¿ô")
x2.set_ylabel("²ÈÄÂ")
x2.plot(sample_history)
x2.grid(True)  # grid on the right-hand axes too, matching x1

plt.show()

結果

numpyを使用した重回帰分析 や scikit-learnを使用した重回帰分析 の結果とほぼ同じ。
左のグラフから、訓練を重ねるごとに損失(Loss)が減少し、精度が上がっている事が分かる。
また、右のグラフからは、サンプルデータ(広さ: 60㎡、築年数: 10年)の家賃の推論値が、学習を重ねるごとに約8万円あたりに収束していく事が分かる。

## サンプル ...  広さ: 60.0㎡, 築年数: 10.0年 ##
## 100: loss: 1.2375797033309937, sample_result: 6.949620246887207 ##
## 200: loss: 0.5576881170272827, sample_result: 7.930976867675781 ##
## 300: loss: 0.5441631078720093, sample_result: 8.066993713378906 ##
## 400: loss: 0.5438786745071411, sample_result: 8.086057662963867 ##
## 500: loss: 0.5438722372055054, sample_result: 8.088766098022461 ##
## 600: loss: 0.5438721179962158, sample_result: 8.089155197143555 ##
## 700: loss: 0.5438721179962158, sample_result: 8.08920669555664 ##
## 800: loss: 0.5438721179962158, sample_result: 8.089208602905273 ##
## 900: loss: 0.5438721179962158, sample_result: 8.089208602905273 ##
## 1000: loss: 0.5438721179962158, sample_result: 8.089208602905273 ##
## 訓練済みネットワークを使用して任意の値を推論してみる ##
広さ: 60.0㎡, 築年数: 10.0年 => 家賃: 8.1万円
広さ: 50.0㎡, 築年数: 10.0年 => 家賃: 7.2万円
広さ: 40.0㎡, 築年数: 10.0年 => 家賃: 6.2万円

sample_rent1_chainer.png


添付ファイル: filesample_rent1_chainer.png 325件 [詳細]

トップ   差分 バックアップ リロード   一覧 単語検索 最終更新   ヘルプ   最終更新のRSS
Last-modified: 2019-11-22 (金) 08:07:51 (1756d)