scatter_ 和 one hot
看了很多博客,中國人寫博客有一個特點就是複制來複制去,根本沒有講到重點,好了廢話不多扯,今天講下 scatter_ 函數。
操作一:
import torch # 導入 torch模塊,這裡操作的都是張量數據
src = torch.arange(1, 11).reshape((2, 5)) # 這裡創建一個 2行5列的數據
print(src) # 打印出來
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
上面這個是準備數據,是一個兩行五列的數據。再創建一個索引數據
index = torch.tensor([[0, 1, 2, 0, 2]])
print(index)
tensor([[0, 1, 2, 0, 2]])
在這之前都是很簡單的,相比讀者肯定能看到,無非就是兩個數據,請耐心往下看
result_1 = torch.zeros(3, 5, dtype=src.dtype) # 創建一個3行5列的數據全是0
print(result_1)
tensor([[0, 0, 0, 0 0],
[0, 0 0, 0, 0],
[0, 0, 0, 0, 0]])
解析來就是使用 scatter_函數: 也就是根據相關索引,把result_1的指定位置填充下
result = result_1.scatter_(0, index, src)
這裡是什麼意思呢, 0 表示按列來處理,result_1 是需要被更改的數據,index是索引位置, src數用來填充的數據,舉例子: 如上面描述:
result_1 = tensor([[0, 0, 0, 0 0],
[0, 0 0, 0, 0],
[0, 0, 0, 0, 0]])
index = tensor([[0, 1, 2, 0, 2]])
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
第一個參數 0 表示按列來處理
索引第1個值為0,這表示第1列的第1個數據設置為scr中的第2個數據
ensor([[1, 0, 0, 0 0],
[0, 0 0, 0, 0],
[0, 0, 0, 0, 0]])
索引第2個值為1,這表示第2列的第2個數據設置為scr中的第2個數據
ensor([[1, 0, 0, 0 0],
[0, 2 0, 0, 0],
[0, 0, 0, 0, 0]])
索引第3個值為2,這表示第3列的第3個數據設置為scr中的第3個數據
ensor([[1, 0, 0, 0 0],
[0, 2 0, 0, 0],
[0, 0, 3, 0, 0]])
索引第4個值為0,這表示第4列的第1個數據設置為scr中的第4個數據
ensor([[1, 0, 0, 4 0],
[0, 2 0, 0, 0],
[0, 0, 3, 0, 0]])
索引第5個值為2,這表示第5列的第3個數據設置為scr中的第5個數據
ensor([[1, 0, 0, 4 0],
[0, 2 0, 0, 0],
[0, 0, 3, 0, 5]])
以上就是詳細的計算流程
操作2:
idx = torch.tensor([[0, 1, 2, 3,4]])
last = torch.zeros(3, 5, dtype=src.dtype).scatter_(dim=1, index=idx, value=2)
這裡第一步我相信大家都熟悉,就是創建一個數據而已,這裡我們理解為索引數據
1、torch.zeros(3, 5, dtype=src.dtype). 表示的是創建一個3行5列的數據矩陣,全是0
tensor([[0, 0, 0, 0 0],
[0, 0 0, 0, 0],
[0, 0, 0, 0, 0]])
2、dim=1,表示是按行計算
3、value,表示相應的位置上設置為某個值
idx = torch.tensor([[0, 1, 2, 3,4]])
表示的是第一行的第 0 1 2 3 4 的位置上全是設置為2,也就是
tensor([[2, 2, 2, 2, 2],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]])
當然,我相信某些人還是一臉懵逼,再繼續往下看
idx = torch.tensor([[0, 1, 2, 3,4],[0,0,0,0,4]])
last = torch.zeros(3, 5, dtype=src.dtype).scatter_(dim=1, index=idx, value=2)
這裡我們看到idx為 torch.tensor([[0, 1, 2, 3,4],[0,0,0,0,4]])
這個idx有兩行,那麼他對應的也是 torch.zeros(3, 5, dtype=src.dtype)中的兩行數據,
[0, 1, 2, 3,4] 表示的是第一行的第 0 1 2 3 4 的位置上全是設置為2
[0,0,0,0,4]]表示的是第二行的第 0 、4 的位置上設置為2,其他地方不變
因此整體數據變成了
tensor([[2, 2, 2, 2, 2],
[2, 0, 0, 0, 2],
[0, 0, 0, 0, 0]])
好了,這個函數介紹到此為止,希望能幫到大家
,更多精彩资讯请关注tft每日頭條,我们将持续为您更新最新资讯!