在Pytorch中实现im2col操作 Implementing im2col in Pytorch

来源:互联网 发布:淘宝海关拍卖网 编辑:程序博客网 时间:2024/06/06 16:33

Pytorch中可以用torch.unfold, torch.cattorch.transpose的组合实现im2col操作.

TAKE AWAY:

stride = (1, 1)kernel_size = (3, 3)x = torch.arange(0, 25).resize_(5, 5)y = torch.cat(torch.cat(x.unfold(0, kernel_size[0], stride[0]).unfold(1, kernel_size[1], stride[1]).transpose(0, 2), dim=2).transpose(0, 1), dim=0)

下面以一个简单小矩阵举例详细说明单通道im2col操作:

x = torch.arange(0, 25).resize_(5, 5)print(x)  0   1   2   3   4  5   6   7   8   9 10  11  12  13  14 15  16  17  18  19 20  21  22  23  24[torch.FloatTensor of size 5x5]

定义卷积核大小和步长

kernel_size = (3, 3)stride = (1, 1)

首先使用unfold将其切片成小矩阵, 先横着切:

x = x.unfold(0, kernel, 1)print(x)(0 ,.,.) =    0   5  10   1   6  11   2   7  12   3   8  13   4   9  14(1 ,.,.) =    5  10  15   6  11  16   7  12  17   8  13  18   9  14  19(2 ,.,.) =   10  15  20  11  16  21  12  17  22  13  18  23  14  19  24[torch.FloatTensor of size 3x5x3]

再竖着切:

x = x.unfold(1, kernel_size[1], stride[1])print(x)(0 ,0 ,.,.) =    0   1   2   5   6   7  10  11  12(0 ,1 ,.,.) =    1   2   3   6   7   8  11  12  13(0 ,2 ,.,.) =    2   3   4   7   8   9  12  13  14(1 ,0 ,.,.) =    5   6   7  10  11  12  15  16  17(1 ,1 ,.,.) =    6   7   8  11  12  13  16  17  18(1 ,2 ,.,.) =    7   8   9  12  13  14  17  18  19(2 ,0 ,.,.) =   10  11  12  15  16  17  20  21  22(2 ,1 ,.,.) =   11  12  13  16  17  18  21  22  23(2 ,2 ,.,.) =   12  13  14  17  18  19  22  23  24[torch.FloatTensor of size 3x3x3x3]

这里要注意, 因为接下来要使用torch.cat做拼接, 但是因为cat操作的一些特点, 需要先用transpose对维度顺序做一下调整, 注意在我这个例子里维度都是3所以可能看不出来, 可以自己做实验试一下维度不相同的情况:

x = x.transpose(0, 2)(0 ,0 ,.,.) =    0   1   2   5   6   7  10  11  12(0 ,1 ,.,.) =    1   2   3   6   7   8  11  12  13(0 ,2 ,.,.) =    2   3   4   7   8   9  12  13  14(1 ,0 ,.,.) =    5   6   7  10  11  12  15  16  17(1 ,1 ,.,.) =    6   7   8  11  12  13  16  17  18(1 ,2 ,.,.) =    7   8   9  12  13  14  17  18  19(2 ,0 ,.,.) =   10  11  12  15  16  17  20  21  22(2 ,1 ,.,.) =   11  12  13  16  17  18  21  22  23(2 ,2 ,.,.) =   12  13  14  17  18  19  22  23  24[torch.FloatTensor of size 3x3x3x3]

然后用cat拼接一下:

x = torch.cat(x, dim=2)print(x)(0 ,.,.) =    0   1   2   5   6   7  10  11  12   5   6   7  10  11  12  15  16  17  10  11  12  15  16  17  20  21  22(1 ,.,.) =    1   2   3   6   7   8  11  12  13   6   7   8  11  12  13  16  17  18  11  12  13  16  17  18  21  22  23(2 ,.,.) =    2   3   4   7   8   9  12  13  14   7   8   9  12  13  14  17  18  19  12  13  14  17  18  19  22  23  24[torch.FloatTensor of size 3x3x9]

这时再用transpose先转置一下:

x = x.transpose(0, 1)print(x)(0 ,.,.) =    0   1   2   5   6   7  10  11  12   1   2   3   6   7   8  11  12  13   2   3   4   7   8   9  12  13  14(1 ,.,.) =    5   6   7  10  11  12  15  16  17   6   7   8  11  12  13  16  17  18   7   8   9  12  13  14  17  18  19(2 ,.,.) =   10  11  12  15  16  17  20  21  22  11  12  13  16  17  18  21  22  23  12  13  14  17  18  19  22  23  24[torch.FloatTensor of size 3x3x9]

最后cat一次就完成啦:

x = torch.cat(x, dim=2)print(x)    0     1     2     5     6     7    10    11    12    1     2     3     6     7     8    11    12    13    2     3     4     7     8     9    12    13    14    5     6     7    10    11    12    15    16    17    6     7     8    11    12    13    16    17    18    7     8     9    12    13    14    17    18    19   10    11    12    15    16    17    20    21    22   11    12    13    16    17    18    21    22    23   12    13    14    17    18    19    22    23    24[torch.FloatTensor of size 9x9]
原创粉丝点击