Python Intro: NumPy array shape

来源:互联网 发布:java线程第三版源码 编辑:程序博客网 时间:2024/05/16 14:41



#!/usr/local/bin/python3
"""Show how NumPy array shapes flow through one forward step of a small network.

``sizes`` lists the number of neurons per layer. For every layer after the
first, ``biases`` holds a column vector of shape ``(y, 1)`` and ``weights`` a
matrix of shape ``(y, x)``, where ``x`` is the size of the previous layer.
The loop feeds an independent random column vector of shape ``(x, 1)`` into
each layer and prints the shapes and the result ``w @ a + b``.
"""
import sys

try:
    import numpy as np
except ImportError:
    # Nothing below can run without numpy; stop with a clear message instead
    # of printing and then crashing with a NameError on ``np``.
    sys.exit("numpy is not installed")

sizes = [37, 5, 6]

biases = [np.random.randn(y, 1) for y in sizes[1:]]
weights = [np.random.randn(y, x) for x, y in zip(sizes[:-1], sizes[1:])]

# Zero-filled arrays with the same shapes as the parameters (not used by the
# loop below; kept for inspection).
nabla_b = [np.zeros(b.shape) for b in biases]
nabla_w = [np.zeros(w.shape) for w in weights]


def weighted_input(w, a, b):
    """Return ``np.dot(w, a) + b`` for one layer, checking ``a`` is a column.

    Args:
        w: weight matrix of shape ``(y, x)``.
        a: layer input; must have shape ``(x, 1)``.
        b: bias column, added to the product (expected shape ``(y, 1)``).

    Returns:
        The array ``np.dot(w, a) + b``.

    Raises:
        ValueError: if ``a`` does not have shape ``(x, 1)``. This catches
            the mistake of wrapping the input in a list (``[array]``), which
            NumPy treats as shape ``(1, x, 1)`` and silently broadcasts
            into a ``(y, y, 1)`` result instead of ``(y, 1)``.
    """
    w = np.asarray(w)
    a = np.asarray(a)
    expected = (w.shape[1], 1)
    if a.shape != expected:
        raise ValueError(f"input has shape {a.shape}, expected {expected}")
    return np.dot(w, a) + b


print("\n\n\n")
print("===============================================================")
print("\n\n\n")
for b, w in zip(biases, weights):
    print(b.shape)
    print("b = ", b)
    print("\n")
    print(w.shape)
    print("w = ", w)
    print("\n")
    # Draw the input directly as an (x, 1) column. Wrapping it in a list,
    # e.g. [np.random.randn(...)], would make the result (y, y, 1).
    a = np.random.randn(w.shape[1], 1)

    print("a = ", a)
    print("\n")
    t = weighted_input(w, a, b)
    print("result = ", t)
    print("-------------------------------------------------------")
    print("\n")

print("\n\n\n")
print("===============================================================")


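a=[np.random.randn(w.shape[1], 1)]; 将导致错误的结果，

而 a=[np.random.randn(w.shape[1], 1)][0]; 将产生正确的结果。原因是（设 w 的形状为 (m, n)）：前者得到的是一个只含一个数组的列表，np.dot 会把它当作形状为 (1, n, 1) 的三维数组，于是 np.dot(w, a) 的形状为 (m, 1, 1)；再与形状为 (m, 1) 的 b 相加时发生广播，结果变成 (m, m, 1)（即下方错误结果中 5 组 5×1 和 6 组 6×1 的输出）。后者用 [0] 取回了数组本身，形状为 (n, 1)，结果为预期的 (m, 1)。更直接的写法是 a = np.random.randn(w.shape[1], 1)。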

另外，shape 是 NumPy 数组（ndarray）的一个属性而不是方法，因此写作 b.shape，不加括号。


错误的结果如下：

===============================================================



(5, 1)
b =  [[-0.80234381]
 [-0.81584549]
 [ 0.19263824]
 [ 0.23136913]
 [-0.35369136]]


(5, 37)
w =  [[-1.22492256 -0.05153998 -0.2416596   0.50049006 -1.49660207  0.85092651
   0.6544673  -0.3754299  -1.80912589  0.37908488  0.42568476  1.49655684
   0.47841895 -1.22748695 -1.44565474  1.59204204  0.54425796 -0.99120238
  -0.56170821  0.70737275 -1.00461932  1.3174263   0.11169217  0.06186711
  -0.85344359  1.16538531  0.41580291  0.01258008  0.20344186  0.41829162
   0.20227069  0.97059174 -0.05373197 -2.19291387 -0.4340676   1.85237748
   0.27900507]
 [-1.45168005 -1.14344156 -0.28568769  0.25735952 -1.26321918 -0.27439448
   0.70492131 -0.90084135 -0.34983977 -0.37380968 -0.83555351 -0.8887116
   0.13211296 -1.96430847  0.79747858  0.50058352 -0.37602784 -0.50755738
   1.14404825 -0.73626532  0.98610876  0.21031306 -1.51556671 -1.34774195
  -0.22559543 -1.4108586   0.62301475  2.17360229 -0.62072268  1.36621206
  -0.33187264 -1.72646447  0.17478133 -0.19081233 -0.52705278  0.47450632
  -1.07083316]
 [-0.11320502  0.73540301  0.21444072 -0.12534763  1.66133067 -0.88573318
   0.19579348 -1.14278562  1.06703606 -0.87959109  0.05832806 -2.84410938
   0.7067185   0.15948225  1.50135313  2.05782267  1.0439752  -0.05334146
  -2.39704177 -0.03350789 -1.06778715  0.10799586 -0.49407107 -0.06985971
  -1.59455816  0.4528298  -0.42533419  0.46746003 -0.20312981 -0.09228466
  -0.43995046  0.51760939  0.24436896 -0.39869921 -0.86469663 -2.68340844
   0.2734433 ]
 [ 0.42368235  0.82065636  0.33775918 -0.6113859  -1.66232777  0.10907928
   0.61912058  0.47998729 -1.11836201 -0.40173418 -0.30956985 -0.52852956
  -0.98234968 -1.19866915 -0.19059407 -0.52731314  0.51324251 -0.12440913
   0.05918655 -0.77988089 -0.4261943   0.44117496 -0.269059   -1.19496792
  -0.4348697  -1.11977948 -0.24246792  2.04652802 -0.69826327 -1.84210738
  -1.97253499 -0.74212964  0.38386701  0.85001423  0.80454352  0.76351279
   0.40977939]
 [-0.25813088  1.81696984 -0.23022431 -0.30215019  1.05691976  0.82763261
   0.94810037 -1.82938546 -0.6754055  -0.20740134 -0.19331754  1.14308525
  -1.64443817 -0.92892617  0.09999019  0.31131367 -1.38775837  0.12595169
   0.02417245  1.72788269  0.47844719  0.35279001 -0.23732005  0.66667847
   0.27974457  0.7541109   1.49282702 -1.59411123  0.24738603 -0.33174431
  -1.11101875 -0.77369412  0.21519189  0.70745882 -0.93253936 -0.33803257
  -0.18741942]]


a =  [array([[ 0.43080953],
       [ 0.6271243 ],
       [ 0.25769139],
       [ 1.15770288],
       [ 0.88505929],
       [-1.18651882],
       [-2.0328503 ],
       [-0.33605496],
       [-0.7184775 ],
       [ 0.2481995 ],
       [ 0.791939  ],
       [ 0.82346004],
       [-0.95404354],
       [-0.80400157],
       [-0.85874463],
       [-0.19380442],
       [-1.4377476 ],
       [ 0.74277695],
       [ 2.1718078 ],
       [-1.42753621],
       [-2.1769708 ],
       [ 0.6066664 ],
       [ 0.89533354],
       [ 0.40967936],
       [ 1.28491686],
       [-0.55703151],
       [-0.25216725],
       [-0.46927728],
       [-1.11216852],
       [ 0.21308665],
       [-1.23149161],
       [ 0.98024705],
       [ 0.25706576],
       [-0.61663706],
       [-1.2014525 ],
       [ 0.80797824],
       [ 0.41965466]])]


result =  [[[  1.58854826]
  [  1.57504658]
  [  2.58353031]
  [  2.6222612 ]
  [  2.03720071]]

 [[ -4.88910208]
  [ -4.90260375]
  [ -3.89412003]
  [ -3.85538914]
  [ -4.44044963]]

 [[-10.38698865]
  [-10.40049033]
  [ -9.3920066 ]
  [ -9.35327571]
  [ -9.9383362 ]]

 [[ -0.03147276]
  [ -0.04497444]
  [  0.96350929]
  [  1.00224018]
  [  0.41717969]]

 [[  1.71963699]
  [  1.70613532]
  [  2.71461904]
  [  2.75334993]
  [  2.16828944]]]
-------------------------------------------------------


(6, 1)
b =  [[-0.01942626]
 [-0.31193881]
 [ 0.43504691]
 [ 1.84491166]
 [-0.80456819]
 [ 0.03021581]]


(6, 5)
w =  [[ -1.13538691e+00  -1.17284504e+00  -6.81504250e-01   2.59648757e-02
    3.42289781e-01]
 [  8.39006451e-01   3.36292700e-01   1.58288576e+00  -1.25206326e-01
    1.37358527e-01]
 [ -4.62542151e-01  -9.32414390e-01   1.05534508e-03   1.99062893e+00
    5.84544636e-01]
 [  4.63710049e-01   2.59477781e-01  -6.14149339e-02   5.65044304e-01
   -6.04752454e-01]
 [  3.73561048e-01  -1.14715510e-01  -3.71172780e-01  -6.14595047e-01
   -8.42810547e-01]
 [  6.79003560e-01   2.42944760e-01   1.59213349e+00  -2.18025362e-01
    1.55950700e+00]]


a =  [array([[-0.40012735],
       [-0.42808614],
       [ 1.13635464],
       [ 0.78782013],
       [ 1.44406382]])]


result =  [[[ 0.67726523]
  [ 0.38475268]
  [ 1.1317384 ]
  [ 2.54160315]
  [-0.1078767 ]
  [ 0.7269073 ]]

 [[ 1.39933608]
  [ 1.10682352]
  [ 1.85380925]
  [ 3.26367399]
  [ 0.61419414]
  [ 1.44897814]]

 [[ 2.97837974]
  [ 2.68586718]
  [ 3.43285291]
  [ 4.84271765]
  [ 2.1932378 ]
  [ 3.0280218 ]]

 [[-0.81398518]
  [-1.10649774]
  [-0.35951201]
  [ 1.05035273]
  [-1.59912712]
  [-0.76434312]]

 [[-2.24283661]
  [-2.53534916]
  [-1.78836344]
  [-0.37849869]
  [-3.02797854]
  [-2.19319454]]

 [[ 3.49437571]
  [ 3.20186316]
  [ 3.94884888]
  [ 5.35871363]
  [ 2.70923378]
  [ 3.54401778]]]
-------------------------------------------------------




===============================================================


正确的结果如下：



===============================================================



(5, 1)
b =  [[-0.262492  ]
 [-0.20629397]
 [-0.01950833]
 [ 0.74556297]
 [-0.59764296]]


(5, 37)
w =  [[-0.9970828   0.73065033 -0.4922562   2.42612649 -0.34932448  0.47006178
   1.54562521  0.60566196 -0.02559831 -0.87583074  0.59756882 -0.7147512
  -0.27414287  0.8355061  -0.36815386 -0.02422284 -0.12012678  1.31111424
  -0.29971747 -0.53969772  1.44673392  1.78797172 -0.15529384 -0.27336318
  -1.77369675 -0.85413537  1.35460962 -0.14779111 -0.78557714 -0.02437009
  -0.7033722  -0.54030338 -0.72616893 -0.16379103 -0.29007255 -0.25834042
   1.36683428]
 [-0.56872961  0.60951407 -0.35310149  0.52013285  0.61663699 -0.30961368
  -0.20963367  0.71705772 -0.5316324   0.25423788 -1.68829098  0.10111508
  -0.70798256  0.50275926 -0.49245537 -0.95913379  0.5142013   0.58778145
  -0.94998304 -0.76775667 -0.34393914 -0.38942604 -1.4718932  -1.03127752
   0.14620239 -1.84580584 -1.63070239 -0.77504503  0.73485183  0.12520007
  -0.938779    1.83107633 -0.93583973 -0.62045299  1.23843946  0.14997827
   1.56348389]
 [ 0.33972002 -1.18879186  0.49104202 -0.07079901  1.08109047 -1.3101739
  -0.86728255 -0.64579715  1.10382172 -1.58168163 -0.24854551 -0.90613182
  -1.65663682  0.58143426  0.27627428  0.16452686  0.8565011   0.42513523
  -2.21190155  0.61925565  1.11953894  1.02110787 -0.17919635  1.95764644
   0.58239631  1.63533897 -0.4562275  -0.67159646  0.41412204  0.11597722
   0.52239521 -1.55371469 -0.02140371  0.47340526 -0.60128875 -0.98142806
  -1.18309974]
 [ 0.60442293 -0.11307345  1.60810941 -1.90193352  0.36896755 -1.055618
   1.33453487 -0.0839704   0.21068519  1.5959486   0.4395267   0.8754726
  -0.58788291  0.88672996  1.55037176 -1.5332406   0.57252895  1.54376175
  -0.20691328  0.77768875 -0.8624152   0.4351931   1.23382085  0.34031488
  -0.27569002  0.48206992  0.2324622  -0.18764399  0.18933579  0.43943348
  -0.0666472  -0.0074146   0.44300101 -0.58442105  1.58259841 -1.45337009
  -0.12645283]
 [-1.17699839  1.42546676 -0.06306605 -1.35392344 -0.52848191  1.12807458
   1.12138442  0.95538687 -0.32399031 -1.02436395 -0.77799109  0.5657171
  -0.52596777 -0.74161448 -0.15159113  0.64114591  1.31848493  1.50740195
  -0.36352739  2.26293174  0.00654035 -1.46270206 -0.31493683  0.35115431
  -0.64182191  0.70758048  0.17557849 -1.0962026  -0.55330834 -0.1487515
   0.85021598 -0.98307117  0.97868651 -0.51840716 -1.23989995 -0.20878353
  -0.29166749]]


a =  [[ 0.57490808]
 [ 0.93551352]
 [ 0.48746221]
 [-1.05142943]
 [ 0.17675075]
 [ 0.30799841]
 [-1.85511516]
 [-0.1011902 ]
 [ 0.17074073]
 [ 0.07103635]
 [-0.3394694 ]
 [-1.21016465]
 [ 2.27487407]
 [ 0.1694055 ]
 [ 0.2834597 ]
 [-0.48268838]
 [ 1.69990823]
 [ 0.18432417]
 [ 0.41151327]
 [ 0.08262408]
 [ 0.11720957]
 [-1.70657664]
 [ 0.47772809]
 [-0.47040207]
 [-0.50707268]
 [-0.50235741]
 [ 0.1109001 ]
 [ 0.0418734 ]
 [-0.91857133]
 [-1.07522034]
 [ 0.02980418]
 [-0.83182824]
 [-0.98320158]
 [ 1.87120068]
 [ 0.26556049]
 [-1.26170258]
 [-0.45455707]]


result =  [[-6.04175764]
 [-2.77462728]
 [-1.38691038]
 [ 0.85500108]
 [ 1.63308819]]
-------------------------------------------------------


(6, 1)
b =  [[ 0.81563918]
 [-0.56125996]
 [ 1.64255057]
 [-1.78762905]
 [-0.10637609]
 [-1.30160507]]


(6, 5)
w =  [[-0.41877003  0.12253528  0.71806184 -0.59645841  0.34434439]
 [ 1.84111609 -1.30775289 -1.24070082  0.55306451 -1.94332213]
 [ 0.40351707  0.17099953  0.60908037 -0.25724238 -0.32904466]
 [-0.15303428 -0.01397066 -0.55895095 -1.47137884 -0.52661826]
 [ 0.2487372  -0.88314338 -0.15866652  0.10156308  1.42306267]
 [-0.76701692  1.19854978  0.31583967  1.04686157  0.26428875]]


a =  [[-0.62313353]
 [ 0.83691557]
 [ 0.16848985]
 [ 1.11337741]
 [ 2.09855381]]


result =  [[ 1.35866856]
 [-6.47444192]
 [ 0.65991579]
 [-4.54147481]
 [ 2.0722289 ]
 [ 1.95282584]]
-------------------------------------------------------




===============================================================



0 0