python pde adi(抛物型差分(二维—ADI格式))

来源:互联网 发布:淘宝客没权重吗 编辑:程序博客网 时间:2024/06/03 17:26
# coding: utf-8
"""2-D heat equation u_t = u_xx + u_yy solved with the Peaceman-Rachford
alternating-direction implicit (ADI) scheme, animated as a 3-D wireframe.

The grid array ``A`` has shape ``(n + 2, m + 2)``: rows index y, columns
index x, and the outermost rows/columns hold fixed (Dirichlet) boundary
values that the solver never changes.
"""
import time

import numpy as np

try:
    # Importing axes3d registers the '3d' projection on older matplotlib.
    from mpl_toolkits.mplot3d import axes3d
    import matplotlib.pyplot as plt
except ImportError:  # plotting is optional; the numerical routines do not need it
    axes3d = None
    plt = None


def createMatrix(m, n):
    """Return the initial grid with ``m`` interior columns and ``n`` interior rows.

    The result has shape ``(n + 2, m + 2)``. Row 0 is set to 100 and row
    ``n + 1`` to 0; column 0 is then set to 75 and column ``m + 1`` to 50,
    so the four corners take the column values. Interior points are 0.
    """
    A = np.zeros((n + 2, m + 2))
    A[0, :] = 100.0
    A[n + 1, :] = 0.0
    A[:, 0] = 75.0       # written after the rows, so it overwrites the corners
    A[:, m + 1] = 50.0
    return A


def oneIter(A, r_lf, r_rt):
    """Perform one ADI half-step: implicit along rows, explicit along columns.

    For every interior point (j, i) this solves

        (1 + r_lf) * u[j, i] - r_lf / 2 * (u[j, i-1] + u[j, i+1])
            = r_rt / 2 * A[j-1, i] + (1 - r_rt) * A[j, i] + r_rt / 2 * A[j+1, i]

    where the boundary columns of ``u`` are taken from ``A``.

    Parameters
    ----------
    A : array_like, shape (n + 2, m + 2)
        Current grid including boundary rows/columns.
    r_lf : float
        Mesh ratio for the implicit (row-wise) direction.
    r_rt : float
        Mesh ratio for the explicit (column-wise) direction.

    Returns
    -------
    ndarray
        New float grid of the same shape; boundary values are copied from A.
    """
    A = np.asarray(A, dtype=float)  # avoid truncating results into an int array
    n = A.shape[0] - 2
    m = A.shape[1] - 2
    B = A.copy()  # boundary rows/columns are carried over unchanged
    if m <= 0 or n <= 0:
        return B  # no interior points to update

    # Tridiagonal system matrix of the implicit direction.
    M = (np.diag(np.full(m, 1.0 + r_lf))
         + np.diag(np.full(m - 1, -r_lf / 2.0), 1)
         + np.diag(np.full(m - 1, -r_lf / 2.0), -1))

    # Explicit contribution from the neighbouring rows, for all rows at once.
    cols = slice(1, m + 1)
    rhs = (r_rt / 2.0 * A[0:n, cols]
           + (1.0 - r_rt) * A[1:n + 1, cols]
           + r_rt / 2.0 * A[2:n + 2, cols])
    # Known boundary values of the implicit stencil move to the right-hand
    # side: the left boundary is column 0, the right boundary column m + 1.
    rhs[:, 0] += r_lf / 2.0 * A[1:n + 1, 0]
    rhs[:, m - 1] += r_lf / 2.0 * A[1:n + 1, m + 1]

    # Solve all n row systems with one call (multiple right-hand sides).
    B[1:n + 1, cols] = np.linalg.solve(M, rhs.T).T
    return B


def computeA(m, n, rx, ry, iter):
    """Run ``iter`` full ADI steps starting from ``createMatrix(m, n)``.

    Progress is printed to stdout. (The parameter name ``iter`` shadows the
    builtin but is kept for backward compatibility with keyword callers.)
    """
    A = createMatrix(m, n)
    print('total iter=%s' % (iter,))
    for step in range(1, iter + 1):
        print('iter num=%s' % (step,))
        A = computeOneIter(A, m, n, rx, ry)
    return A


def computeOneIter(A, m, n, rx, ry):
    """Advance the grid by one full Peaceman-Rachford ADI step.

    ``m`` and ``n`` are unused (sizes are taken from ``A``); they are kept
    only for backward compatibility with existing callers.
    """
    half = oneIter(A, rx, ry)                 # implicit along rows (x)
    return np.transpose(oneIter(np.transpose(half), ry, rx))  # implicit along columns (y)


def getStart():
    """Set up the grid, run the ADI time loop and animate each step.

    Returns
    -------
    tuple
        ``(A, x, y)``: the final grid and the x / y coordinate vectors used
        for plotting.

    Raises
    ------
    ImportError
        If matplotlib is not available.
    """
    if plt is None:
        raise ImportError('matplotlib is required for getStart()')
    X_INTERVAL = [0, 20]
    Y_INTERVAL = [0, 30]
    T = [0, 10]
    deltax = 0.5
    deltay = 0.3
    hmin = min(deltax, deltay)
    tao = 1.0 / 3 * hmin * hmin
    # round() guards against quotients like 99.9999999 that int() would truncate.
    m = int(round((X_INTERVAL[1] - X_INTERVAL[0]) / deltax)) - 1
    n = int(round((Y_INTERVAL[1] - Y_INTERVAL[0]) / deltay)) - 1
    print('m=%s,n=%s' % (m, n))

    # NOTE(review): the simulated grid is createMatrix(m - 2, n - 2), i.e. m
    # x-points and n y-points *including* the boundary, so its real spacing
    # is 20 / (m - 1) and 30 / (n - 1) rather than deltax / deltay used in
    # the mesh ratios below. Kept as-is to preserve the returned shapes.
    x = np.linspace(X_INTERVAL[0], X_INTERVAL[1], m)
    y = np.linspace(Y_INTERVAL[0], Y_INTERVAL[1], n)
    rx = tao / deltax / deltax
    ry = tao / deltay / deltay
    steps = int((T[1] - T[0]) / tao)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    X, Y = np.meshgrid(x, y)
    wframe = None
    A = createMatrix(m - 2, n - 2)
    for i in range(steps):
        A = computeOneIter(A, m, n, rx, ry)
        if wframe is not None:
            # ax.collections.remove(...) no longer works in matplotlib >= 3.7.
            wframe.remove()
        wframe = ax.plot_wireframe(X, Y, A, rstride=2, cstride=2)
        plt.pause(0.01)
        print('iter= %s' % (i,))
    return A, x, y


if __name__ == '__main__':
    getStart()

0 0
原创粉丝点击