在Pytorch中实现im2col操作 Implementing im2col in Pytorch
来源:互联网 发布:淘宝海关拍卖网 编辑:程序博客网 时间:2024/06/06 16:33
在Pytorch中可以用torch.unfold, torch.cat和torch.transpose的组合实现im2col操作.
TAKE AWAY:
stride = (1, 1)kernel_size = (3, 3)x = torch.arange(0, 25).resize_(5, 5)y = torch.cat(torch.cat(x.unfold(0, kernel_size[0], stride[0]).unfold(1, kernel_size[1], stride[1]).transpose(0, 2), dim=2).transpose(0, 1), dim=0)
下面以一个简单小矩阵举例详细说明单通道im2col操作:
x = torch.arange(0, 25).resize_(5, 5)print(x) 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24[torch.FloatTensor of size 5x5]
定义卷积核大小和步长
kernel_size = (3, 3)stride = (1, 1)
首先使用unfold将其切片成小矩阵, 先横着切:
x = x.unfold(0, kernel, 1)print(x)(0 ,.,.) = 0 5 10 1 6 11 2 7 12 3 8 13 4 9 14(1 ,.,.) = 5 10 15 6 11 16 7 12 17 8 13 18 9 14 19(2 ,.,.) = 10 15 20 11 16 21 12 17 22 13 18 23 14 19 24[torch.FloatTensor of size 3x5x3]
再竖着切:
x = x.unfold(1, kernel_size[1], stride[1])print(x)(0 ,0 ,.,.) = 0 1 2 5 6 7 10 11 12(0 ,1 ,.,.) = 1 2 3 6 7 8 11 12 13(0 ,2 ,.,.) = 2 3 4 7 8 9 12 13 14(1 ,0 ,.,.) = 5 6 7 10 11 12 15 16 17(1 ,1 ,.,.) = 6 7 8 11 12 13 16 17 18(1 ,2 ,.,.) = 7 8 9 12 13 14 17 18 19(2 ,0 ,.,.) = 10 11 12 15 16 17 20 21 22(2 ,1 ,.,.) = 11 12 13 16 17 18 21 22 23(2 ,2 ,.,.) = 12 13 14 17 18 19 22 23 24[torch.FloatTensor of size 3x3x3x3]
这里要注意, 因为接下来要使用torch.cat做拼接, 但是因为cat操作的一些特点, 需要先用transpose对维度顺序做一下调整, 注意在我这个例子里维度都是3所以可能看不出来, 可以自己做实验试一下维度不相同的情况:
x = x.transpose(0, 2)(0 ,0 ,.,.) = 0 1 2 5 6 7 10 11 12(0 ,1 ,.,.) = 1 2 3 6 7 8 11 12 13(0 ,2 ,.,.) = 2 3 4 7 8 9 12 13 14(1 ,0 ,.,.) = 5 6 7 10 11 12 15 16 17(1 ,1 ,.,.) = 6 7 8 11 12 13 16 17 18(1 ,2 ,.,.) = 7 8 9 12 13 14 17 18 19(2 ,0 ,.,.) = 10 11 12 15 16 17 20 21 22(2 ,1 ,.,.) = 11 12 13 16 17 18 21 22 23(2 ,2 ,.,.) = 12 13 14 17 18 19 22 23 24[torch.FloatTensor of size 3x3x3x3]
然后用cat拼接一下:
x = torch.cat(x, dim=2)print(x)(0 ,.,.) = 0 1 2 5 6 7 10 11 12 5 6 7 10 11 12 15 16 17 10 11 12 15 16 17 20 21 22(1 ,.,.) = 1 2 3 6 7 8 11 12 13 6 7 8 11 12 13 16 17 18 11 12 13 16 17 18 21 22 23(2 ,.,.) = 2 3 4 7 8 9 12 13 14 7 8 9 12 13 14 17 18 19 12 13 14 17 18 19 22 23 24[torch.FloatTensor of size 3x3x9]
这时再用transpose先转置一下:
x = x.transpose(0, 1)print(x)(0 ,.,.) = 0 1 2 5 6 7 10 11 12 1 2 3 6 7 8 11 12 13 2 3 4 7 8 9 12 13 14(1 ,.,.) = 5 6 7 10 11 12 15 16 17 6 7 8 11 12 13 16 17 18 7 8 9 12 13 14 17 18 19(2 ,.,.) = 10 11 12 15 16 17 20 21 22 11 12 13 16 17 18 21 22 23 12 13 14 17 18 19 22 23 24[torch.FloatTensor of size 3x3x9]
最后cat一次就完成啦:
x = torch.cat(x, dim=2)print(x) 0 1 2 5 6 7 10 11 12 1 2 3 6 7 8 11 12 13 2 3 4 7 8 9 12 13 14 5 6 7 10 11 12 15 16 17 6 7 8 11 12 13 16 17 18 7 8 9 12 13 14 17 18 19 10 11 12 15 16 17 20 21 22 11 12 13 16 17 18 21 22 23 12 13 14 17 18 19 22 23 24[torch.FloatTensor of size 9x9]
阅读全文
0 0
- 在Pytorch中实现im2col操作 Implementing im2col in Pytorch
- im2col的原理和实现
- darknet中im2col代码分析
- 学习:im2col
- im2col算法
- PyTorch
- PyTorch
- PyTorch
- pytorch
- pytorch
- Pytorch
- Caffe中卷积的实现细节(涉及到BaseConvolutionLayer、ConvolutionLayer、im2col等)
- Caffe中卷积的实现细节(涉及到BaseConvolutionLayer、ConvolutionLayer、im2col等)
- Python中如何实现im2col和col2im函数(sliding类型)
- 【pytorch源码赏析】Dataset in pytorch
- 【pytorch】图像基本操作
- Matlab 之 im2col
- im2col函数的用法
- mysql集群的使用与简单测试
- Golang中net/http包源码分析与解释
- 使用百度地图API将输入地址转化成坐标
- 独热编码通俗理解和实例
- compileSdkVersion 'android-24' requires JDK 1.8 or later to compile
- 在Pytorch中实现im2col操作 Implementing im2col in Pytorch
- 测试
- linux配置samba
- OpenGL之gluPerspective浅析
- tomcat版本打印console问题引起的锁问题
- git操作流程图(简洁)
- 机器学习中的End-to-End到底是怎么回事?
- 超级终端调用短信猫发送短信说明
- windows下mysql免安装版教程