Usage of numpy.expand_dims in Python

Source: Internet  Published: Java test engineer  Editor: 程序博客网  Time: 2024/05/01 19:06

1 Check the help

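In effect, expand_dims(a, axis) inserts a new axis of length 1 at position axis of the array's shape. Because that axis has length 1, all of the original data sits at index 0 along it.
For example, for a 1-D array with 2 elements, axis=0 gives shape (1, 2) and axis=1 gives shape (2, 1).
Likewise, for an array of shape (2, 3), axis=0 gives (1, 2, 3), axis=1 gives (2, 1, 3), and axis=2 gives (2, 3, 1).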
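The shape rule above can be checked with a minimal sketch. The helper `expanded_shape` is hypothetical (it is not part of NumPy): it splices a 1 into the shape tuple at position `axis`, normalizing a negative axis as `axis + ndim + 1`, and the loop compares its result with `np.expand_dims` for the two shapes used in the examples.

```python
import numpy as np


def expanded_shape(shape, axis):
    """Hypothetical helper: the shape np.expand_dims(a, axis) is expected to produce."""
    ndim = len(shape)
    if axis < 0:
        # A negative axis counts from the end of the result, which has ndim + 1 axes
        axis += ndim + 1
    if not 0 <= axis <= ndim:
        raise ValueError(f"axis out of range for a {ndim}-D array")
    # Insert a length-1 axis at position `axis`
    return tuple(shape[:axis]) + (1,) + tuple(shape[axis:])


for shape in [(2,), (2, 3)]:
    a = np.zeros(shape)
    # Every valid axis, negative and non-negative, should agree with NumPy
    for axis in range(-(a.ndim + 1), a.ndim + 1):
        assert np.expand_dims(a, axis).shape == expanded_shape(shape, axis)
    print(shape, [expanded_shape(shape, ax) for ax in range(a.ndim + 1)])
```

Running it prints `(2,) [(1, 2), (2, 1)]` and `(2, 3) [(1, 2, 3), (2, 1, 3), (2, 3, 1)]`, matching the shapes listed above.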

import numpy as np
help(np.expand_dims)
Help on function expand_dims in module numpy.lib.shape_base:

expand_dims(a, axis)
    Expand the shape of an array.

    Insert a new axis, corresponding to a given position in the array shape.

    Parameters
    ----------
    a : array_like
        Input array.
    axis : int
        Position (amongst axes) where new axis is to be inserted.

    Returns
    -------
    res : ndarray
        Output array. The number of dimensions is one greater than that of
        the input array.

    See Also
    --------
    doc.indexing, atleast_1d, atleast_2d, atleast_3d

    Examples
    --------
    >>> x = np.array([1,2])
    >>> x.shape
    (2,)

    The following is equivalent to ``x[np.newaxis,:]`` or ``x[np.newaxis]``:

    >>> y = np.expand_dims(x, axis=0)
    >>> y
    array([[1, 2]])
    >>> y.shape
    (1, 2)

    >>> y = np.expand_dims(x, axis=1)  # Equivalent to x[:,newaxis]
    >>> y
    array([[1],
           [2]])
    >>> y.shape
    (2, 1)

    Note that some examples may use ``None`` instead of ``np.newaxis``.  These
    are the same objects:

    >>> np.newaxis is None
    True
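The help text says `np.expand_dims(x, axis=0)` is equivalent to `x[np.newaxis, :]`, and that `np.newaxis` is `None`. A minimal check of those statements follows; the `reshape` comparisons are an addition of mine, not part of the help text, and only hold here because the exact target shapes are spelled out.

```python
import numpy as np

x = np.array([1, 2])

# axis=0: new leading axis
a0 = np.expand_dims(x, axis=0)
assert a0.shape == (1, 2)
assert np.array_equal(a0, x[np.newaxis, :])
assert np.array_equal(a0, x[None])            # np.newaxis is just None
assert np.array_equal(a0, x.reshape(1, 2))

# axis=1: new trailing axis
a1 = np.expand_dims(x, axis=1)
assert a1.shape == (2, 1)
assert np.array_equal(a1, x[:, np.newaxis])
assert np.array_equal(a1, x.reshape(2, 1))

print(np.newaxis is None)  # True
```

`np.array_equal` also compares shapes, so each assertion checks both the values and the position of the new axis.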

2 Testing with 1-D data

x = np.array([1,2,3])
print(x)
print(x.shape)

[1 2 3]
(3,)

y = np.expand_dims(x,axis=0)
print(y)
print("y.shape: ", y.shape)
print("y[0][1]: ", y[0][1])

[[1 2 3]]
y.shape:  (1, 3)
y[0][1]:  2

y = np.expand_dims(x,axis=1)
print(y)
print("y.shape: ", y.shape)
print("y[1][0]: ", y[1][0])

[[1]
 [2]
 [3]]
y.shape:  (3, 1)
y[1][0]:  2

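Both `y[0][1]` and `y[1][0]` print 2 because the new axis has length 1 and the data sits at index 0 along it: with axis=0, element `x[i]` becomes `y[0][i]`, and with axis=1 it becomes `y[i][0]`. A short sketch confirming that mapping for the same array:

```python
import numpy as np

x = np.array([1, 2, 3])
y0 = np.expand_dims(x, axis=0)  # shape (1, 3)
y1 = np.expand_dims(x, axis=1)  # shape (3, 1)

for i in range(len(x)):
    # axis=0: the only valid first index is 0
    assert y0[0][i] == x[i]
    # axis=1: the only valid second index is 0
    assert y1[i][0] == x[i]

# Selecting index 0 of the new axis recovers the original 1-D array
assert np.array_equal(y0[0, :], x)
assert np.array_equal(y1[:, 0], x)
print(y0[0, :], y1[:, 0])  # [1 2 3] [1 2 3]
```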
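# axis=3 is larger than x.ndim (1): older NumPy accepted it and put the new axis last;
# recent NumPy raises AxisError here instead (use axis=1 or axis=-1)
y = np.expand_dims(x,axis=3)
print(y)
print("y.shape: ", y.shape)
print("y[2][0]: ", y[2][0])

[[1]
 [2]
 [3]]
y.shape:  (3, 1)
y[2][0]:  3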
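The axis=3 result above depends on older NumPy behavior, which placed an out-of-range axis at the end. On a recent NumPy the same call raises `AxisError`, which subclasses both `ValueError` and `IndexError`; the sketch below catches it as `IndexError` so it does not depend on where `AxisError` is exported in a given NumPy version. For a 1-D array the valid axis values are -2 through 1, and axis=-1 is a portable way to get the trailing axis.

```python
import numpy as np

x = np.array([1, 2, 3])

try:
    y = np.expand_dims(x, axis=3)  # axis > x.ndim
    print("older NumPy accepted it, shape", y.shape)
except IndexError as err:  # AxisError subclasses IndexError and ValueError
    print("recent NumPy rejects it:", err)

# Portable ways to add the trailing axis the old behavior produced
y_last = np.expand_dims(x, axis=-1)
assert y_last.shape == (3, 1)
assert np.expand_dims(x, axis=x.ndim).shape == (3, 1)  # x.ndim is the largest valid axis
print(y_last[2][0])  # 3
```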

3 Testing with 2-D data

x = np.array([[1,2,3],[4,5,6]])
print(x)
print(x.shape)

[[1 2 3]
 [4 5 6]]
(2, 3)

y = np.expand_dims(x,axis=0)
print(y)
print("y.shape: ", y.shape)
print("y[0][1]: ", y[0][1])

[[[1 2 3]
  [4 5 6]]]
y.shape:  (1, 2, 3)
y[0][1]:  [4 5 6]

y = np.expand_dims(x,axis=1)
print(y)
print("y.shape: ", y.shape)
print("y[1][0]: ", y[1][0])

[[[1 2 3]]

 [[4 5 6]]]
y.shape:  (2, 1, 3)
y[1][0]:  [4 5 6]

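The 2-D results follow the same rule: all of the original rows and columns are reachable at index 0 of the inserted axis. One way to confirm this is `np.squeeze` with an explicit `axis`, which removes a length-1 axis and so undoes `expand_dims`; using `squeeze` and `np.take` for the check is my choice, not something the original post does.

```python
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]])

for axis in range(x.ndim + 1):  # valid non-negative positions for a 2-D array: 0, 1, 2
    y = np.expand_dims(x, axis)
    # Removing the same length-1 axis gives back the original array
    assert np.array_equal(np.squeeze(y, axis=axis), x)
    # Equivalently, take index 0 along the new axis
    assert np.array_equal(np.take(y, 0, axis=axis), x)
    print(axis, y.shape)
```

It prints `0 (1, 2, 3)`, `1 (2, 1, 3)` and `2 (2, 3, 1)`.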
# axis=3 is larger than x.ndim (2): older NumPy accepted it and put the new axis last;
# recent NumPy raises AxisError here instead (use axis=2 or axis=-1)
y = np.expand_dims(x,axis=3)
print(y)
print("y.shape: ", y.shape)
print("y[2][0]: ", y[2][0])  # raises IndexError: axis 0 of y has size 2, so valid indices are 0 and 1

[[[1]
  [2]
  [3]]

 [[4]
  [5]
  [6]]]
y.shape:  (2, 3, 1)
IndexError: index 2 is out of bounds for axis 0 with size 2
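The IndexError comes from the indexing, not from `expand_dims`: `y` has shape (2, 3, 1), so its first index can only be 0 or 1. A minimal sketch of reading the same data correctly, assuming a recent NumPy, where the valid axis for appending to a 2-D array is 2 (or -1) rather than 3:

```python
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]])
y = np.expand_dims(x, axis=-1)  # same (2, 3, 1) result the old axis=3 call gave
print(y.shape)  # (2, 3, 1)

# Index order follows the shape: row (0..1), column (0..2), new axis (0 only)
print(y[1][2][0])  # 6
print(y[1][0])     # [4], a length-1 array because the new axis is last

# Check an index against the shape before using it
i = 2
if i < y.shape[0]:
    print(y[i][0])
else:
    print(f"index {i} is out of range for axis 0 of size {y.shape[0]}")

# Dropping the trailing axis recovers the 2-D data
assert np.array_equal(y[:, :, 0], x)
```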