Deep

Deepラーニングのメモです

1172 views

pytorchのgatherが謎

pytorchでgatherを使っているサンプルコードがあったのだが、gatherの意味が分からなかったので調べた。
gatherという英単語の意味は「集める」という意味。

まずはサンプルコード。

if __name__ == '__main__':
    input = torch.tensor([
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]
    ])

    indices = torch.tensor([
        [2, 1, 0],
        [0, 1, 2],
        [2, 0, 1]])

    # gatherは集めるの意味
    result1 = torch.gather(input=input, dim=0, index=indices)
    result2 = torch.gather(input=input, dim=1, index=indices)
    print(result1)
    print(result2)

以下のinputの2次元配列と

    input = torch.tensor([
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]
    ])

以下のindiciesの二つの2次元配列を用意する。

    indices = torch.tensor([
        [2, 1, 0],
        [0, 1, 2],
        [2, 0, 1]])

これをgatherにかける。

    result1 = torch.gather(input=input, dim=0, index=indices)
    result2 = torch.gather(input=input, dim=1, index=indices)

すると、答えは以下。

tensor([[7, 5, 3],
        [1, 5, 9],
        [7, 2, 6]])
tensor([[3, 2, 1],
        [4, 5, 6],
        [9, 7, 8]])

何をやっているのか。

dimが0の場合

indicesの各行の値は、inputの各列のindexを表している。

indicesの0行目を見ると、2,1,0である。
indicesの赤枠の意味は0列目の2番目の値を見ろ、という意味。
indicesの青枠の意味は1列目の1番目の値を見ろ、という意味。
indicesの緑枠の意味は2列目の0番目の値を見ろ、という意味。
つまり、inputに記している赤字の部分が該当し、[7,5,3]が出力される。
図で表すと、以下の関係となる。

次にindicesの1行目を見ると、0,1,2である。
indicesの赤枠の意味は0列目の0番目の値を見ろ、という意味。
indicesの青枠の意味は1列目の1番目の値を見ろ、という意味。
indicesの緑枠の意味は2列目の2番目の値を見ろ、という意味。
つまり、inputに記している赤字の部分が該当し、[1,5,9]が出力される。
図で表すと、以下の関係となる。

最後にindicesの2行目を見ると、2,0,1である。
indicesの赤枠の意味は0列目の2番目の値を見ろ、という意味。
indicesの青枠の意味は1列目の0番目の値を見ろ、という意味。
indicesの緑枠の意味は2列目の1番目の値を見ろ、という意味。
つまり、inputに記している赤字の部分が該当し、[7,2,6]が出力される。
図で表すと、以下の関係となる。

dimが1の場合

dim1のほうが動きが分かりやすい。

まず、indicesの0行目を見ると、[2, 1, 0]である。
indicesの赤枠の意味は0行目の2番目の値を見ろ、という意味。
indicesの青枠の意味は0行目の1番目の値を見ろ、という意味。
indicesの緑枠の意味は0行目の0番目の値を見ろ、という意味。
つまり、[3,1,1]となり、indicesの各色の枠がinputの各色の値を表している。
図で表すと以下の関係になる。

次に、indicesの1行目を見ると、[0, 1, 2]である。
indicesの赤枠の意味は1行目の0番目の値を見ろ、という意味。
indicesの青枠の意味は1行目の1番目の値を見ろ、という意味。
indicesの緑枠の意味は1行目の2番目の値を見ろ、という意味。
つまり、[4,5,6]となり、indicesの各色の枠がinputの各色の値を表している。
図で表すと以下の関係になる。

最後に、indicesの2行目を見ると、[2, 0, 1]である。
indicesの赤枠の意味は2行目の2番目の値を見ろ、という意味。
indicesの青枠の意味は2行目の0番目の値を見ろ、という意味。
indicesの緑枠の意味は2行目の1番目の値を見ろ、という意味。
つまり、[9,7,8]となり、indicesの各色の枠がinputの各色の値を表している。
図で表すと以下の関係になる。

お役に立ったようでしたら、左上の★ボタンのクリックをお願い致します。

Page 17 of 29.

前のページ 次のページ



[添付ファイル]


お問い合わせ

プロフィール

マッスル

自己紹介

本サイトの作成者。
趣味:プログラム/水耕栽培/仮想通貨/激辛好き
プログラムは趣味と勉強を兼ねて、のんびり本サイトを作っています。
フレームワークはdjango。
仮想通貨はNEMが好き。
水耕栽培は激辛好きが高じて、キャロライナ・リーパーの栽培にチャレンジ中。

サイト/ブログ

https://www.osumoi-stdio.com/pyarticle/

ツイッター

@darkimpact0626