【python代码技巧2】数组索引扩增技巧

来源:互联网 发布:盛势网络剧剧照 编辑:程序博客网 时间:2024/06/06 08:42
import numpy as npimport torcha = np.array([x for x in range(12)]).reshape(3, 4)index = np.array([[0,1] for x in range(5)]).flatten()print "a:\n", a, a.shapeprint "index:\n", index, index.shape
a:[[ 0  1  2  3] [ 4  5  6  7] [ 8  9 10 11]] (3, 4)index:[0 1 0 1 0 1 0 1 0 1] (10,)
print a[index], a[index].shape
[[0 1 2 3] [4 5 6 7] [0 1 2 3] [4 5 6 7] [0 1 2 3] [4 5 6 7] [0 1 2 3] [4 5 6 7] [0 1 2 3] [4 5 6 7]] (10, 4)

pytorch同样有这种特性

a = torch.Tensor([x for x in range(12)]).view(3, 4)index = torch.LongTensor([[0,1] for x in range(5)]).view(-1)print "a:\n", aprint "index:\n", index
a:  0   1   2   3  4   5   6   7  8   9  10  11[torch.FloatTensor of size 3x4]index: 0 1 0 1 0 1 0 1 0 1[torch.LongTensor of size 10]
print a[index], a[index].size()
    0     1     2     3    4     5     6     7    0     1     2     3    4     5     6     7    0     1     2     3    4     5     6     7    0     1     2     3    4     5     6     7    0     1     2     3    4     5     6     7[torch.FloatTensor of size 10x4] torch.Size([10, 4])
原创粉丝点击