tensorflow源码解读之gen_array_ops.py文件的space_to_batch_nd方法

来源:互联网 发布:c语言定时器函数 编辑:程序博客网 时间:2024/06/07 13:15

这个文件是编译时生成的,在array_ops.py文件的space_to_batch方法调用到。

这个操作是atrous_conv2d为了给gen_nn_ops.conv2d传参数之前的参数调整。
而gen_nn_ops.conv2d的input参数一直是[batch, in_height, in_width, in_channels]
例如这里的[1,5,5,1]经过space_to_batch_nd方法之后变成[4,3,3,1]

(这里array_ops.space_to_batch传入的paddings参数是[[0,1],[0,1]],感觉就是(5+1)*(5+1)==4*3*3的样子,
又例如space_to_batch_nd的input是[2, 2, 4, 1]而paddings是[[0, 0], [2, 0]],output是[8, 1, 3, 1]也即2*2*(2+4)==8*3),

可以看出就是batch变大了而in_height, in_width变小了,也就是space_to_batch的意思,

继续回到这里的例子,gen_nn_ops.conv2d的Filter参数是这样的 [kernel_height, kernel_width, output_depth, input_depth],这时的Filter是[3,3,1,1],和上面的[4,3,3,1]是可以一起传入gen_nn_ops.conv2d的,经过gen_nn_ops.conv2d方法之后返回的value是[4,1,1,1],最后经过array_ops.batch_to_space方法返回[1,1,1,1]

0 0