Deep

Deepラーニングのメモです

786 views

入力が1チャンネル、出力が1チャンネルの場合

コンボリューション層で、どのような計算を行っているのか、pytorchを用いて確認する。
最終目標はconvの1×1が一体何をやっているのかを知ることが目的。

まずは1チャンネル3×3の仮想画像データを作成する。

import numpy as np
import torch
import torch.nn as nn


def conv_test_net():

    number_net = nn.Sequential(
        nn.Conv2d(1, 1, 3, padding=1),
    )

    return number_net


if __name__ == '__main__':
    # 1チャネル 3×3の画像データを1枚作成
    virtual_img = [
            [
                [
                    [0,0,0],
                    [0,1,0],
                    [0,0,0]
                ]
            ]

        ]

    t = torch.tensor(virtual_img, dtype=torch.float)
    net = conv_test_net()
    conv = net[0]
    print("[weight]")
    print(conv.weight)
    print("[bias]")
    print(conv.bias)

    y = net(t)

    print(y.size())
    print(y)

10行目で入力1チャンネル、出力1枚、3×3のコンボリューション層を作成している。
18行目はテスト用の画像である。1×1×3×3のデータを作成している。
29行目でtorch用のデータに変換している。
30行目はnetworkを作成している。
33行目でコンボリューション層のweightを、35行目でbiasを表示している。
37行目でデータをネットワークに投入している。
39行目で出力のサイズを確認している。
40行目はコンボリューション層の出力結果を確認している。

実行した出力を以下に記す。

[weight]
Parameter containing:
tensor([[[[ 0.2893,  0.1318,  0.1259],
          [-0.2801,  0.0896, -0.3314],
          [ 0.2742,  0.1006,  0.2215]]]], requires_grad=True)
[bias]
Parameter containing:
tensor([0.3101], requires_grad=True)
torch.Size([1, 1, 3, 3])
tensor([[[[ 0.5316,  0.4107,  0.5843],
          [-0.0213,  0.3997,  0.0300],
          [ 0.4360,  0.4419,  0.5994]]]], grad_fn=<MkldnnConvolutionBackward>)

どういう計算をしているかというと、


出力の1行1列目は、上記画像の同じ色同士の値を掛けて、さらにbiasの値である、0.3101を加算している。
paddingは0計算となるため、計算式は次のとおりである。

padding×0.2893 + padding×0.1318 + padding×0.1259 + padding×-0.2801 + 0×0.0896 + 0×-0.3314 + padding×0.2742 + 0×0.1006 + 1×0.2215 + 0.3101=0.5316

となる。なお、baisは最後に1回足すだけである。
あとは、赤枠を横に1つストライドして同じ計算をしていく。
と、ここまではよい。

では、入力画像が3チャンネルの場合は、どのように計算するか、である。

Page 7 of 29.

前のページ 次のページ



[添付ファイル]


お問い合わせ

プロフィール

マッスル

自己紹介

本サイトの作成者。
趣味:プログラム/水耕栽培/仮想通貨/激辛好き
プログラムは趣味と勉強を兼ねて、のんびり本サイトを作っています。
フレームワークはdjango。
仮想通貨はNEMが好き。
水耕栽培は激辛好きが高じて、キャロライナ・リーパーの栽培にチャレンジ中。

サイト/ブログ

https://www.osumoi-stdio.com/pyarticle/

ツイッター

@darkimpact0626