Deep

Deepラーニングのメモです

200 views

pytorchの重みをカスタマイズする方法

# coding: UTF-8
import numpy as np
import torch
import torchvision.transforms.functional
from torch import nn
from torch import optim


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        middle_num = 4

        dummy_weight = [
            [1., 0.],
            [1., 0.],
            [1., 0.],
            [1., 0.],
        ]

        dummy_weight = torch.tensor(dummy_weight)
        self.dense = nn.Linear(2, middle_num)
        print(self.dense.weight)
        #torch.nn.init.ones_(self.dense.weight)
        #torch.nn.init.ones_(self.dense.bias)
        print(self.dense.state_dict())
        self.dense.weight = nn.Parameter(dummy_weight)
        print(self.dense.weight)
        #print(self.dense.weight)
        self.relu = nn.ReLU()
        self.out = nn.Linear(middle_num, 1)
        #torch.nn.init.ones_(self.out.weight)
        #torch.nn.init.ones_(self.out.bias)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.dense(x)
        print("↓↓↓↓↓↓↓↓↓↓")
        print(self.dense.weight)
        x = self.relu(x)
        x = self.out(x)
        x = self.sigmoid(x)
        return x


if __name__ == '__main__':
    # XORのデータを作成
    X = [[0., 0.],
         [0., 1.],
         [1., 0.],
         [1., 1.]]

    Y = [
        [0.],
        [1.],
        [1.],
        [0.]
    ]

    X = torch.tensor(X)
    Y = torch.tensor(Y)

    net = MyNet()
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.01)

    for epoch in range(100):
        optimizer.zero_grad()
        y_pred = net(X)
        print("result:{}".format(y_pred))

        loss = loss_fn(y_pred, Y)
        loss.backward()
        print("loss:{:.3f}".format(loss))
        optimizer.step()

27行目、nn.Parameterで値をすり替えることがポイント。

Page 21 of 29.

前のページ 次のページ



[添付ファイル]


お問い合わせ

プロフィール

マッスル

自己紹介

本サイトの作成者。
趣味:プログラム/水耕栽培/仮想通貨/激辛好き
プログラムは趣味と勉強を兼ねて、のんびり本サイトを作っています。
フレームワークはdjango。
仮想通貨はNEMが好き。
水耕栽培は激辛好きが高じて、キャロライナ・リーパーの栽培にチャレンジ中。

サイト/ブログ

https://www.osumoi-stdio.com/pyarticle/

ツイッター

@darkimpact0626