numpy的axis的学习

来源:互联网 发布:iphone看片软件 编辑:程序博客网 时间:2024/05/29 13:50
import numpy as npa = np.array([[[1, 2, 4], [1, 2, 4]], [[3, 2, 1], [1, 2, 4]], [[3, 2, 1], [1, 2, 4]]])print(a.shape)b = np.max(a, axis=0)print(b.shape)print(b)c = np.max(a, axis=1)print(c.shape)print(c)d = np.max(a, axis=2)print(d.shape)print(d)

a的形状是(3, 2, 3)
b的形状是

312244(b)

我们发现max和axis是一种降维的方法,变为(2, 3),我们可以理解为axis=0是在一个batch进行的。
c的形状是
133222444(c)

d的形状是
433444(d)

也就是说,axis是用来选定在哪一个维度进行计算的,axis从小到大,范围也越来越小,是一种降维的方法。