Chain matrix product (matrix chain multiplication)

来源:互联网 发布:science online数据库 编辑:程序博客网 时间:2024/06/08 18:53
# Matrix chain multiplication (dynamic programming).
#
# Given matrices A1, A2, ..., An with sizes m0*m1, m1*m2, ..., m(n-1)*mn:
# multiplying an (m x n) matrix by an (n x p) matrix needs m*n*p scalar
# multiplications.
#
# Let cost(i, j) be the minimum cost of computing the product Ai...Aj.
# That product has size m(i-1) x m(j), so
#     cost(i, i) = 0
#     cost(i, j) = min over i <= k < j of
#                  cost(i, k) + cost(k+1, j) + m(i-1) * m(k) * m(j)
#
# The code uses 0-based indices: shapes[i] is the shape of A(i+1), i.e.
# (m(i), m(i+1)), so the split term becomes
# shapes[i][0] * shapes[k][1] * shapes[j][1].


def chain_matrix_product(shapes):
    """Find the cheapest order in which to multiply a chain of matrices.

    Args:
        shapes: Sequence of (rows, cols) pairs, one per matrix, in
            multiplication order, e.g. [(6, 3), (3, 1), (1, 3), (3, 8)].
            Adjacent shapes must be compatible: shapes[i][1] == shapes[i+1][0].

    Returns:
        A tuple (min_cost, partitions). min_cost is the minimum number of
        scalar multiplications needed for the whole product. partitions
        lists the chosen splits in pre-order as (i, k, j) tuples (0-based,
        inclusive): the product of matrices i..j is computed as
        (i..k) * (k+1..j). A single matrix yields (0, []).

    Raises:
        ValueError: if shapes is empty, or two adjacent shapes cannot be
            multiplied because their inner dimensions differ.

    Runs in O(n^3) time for n matrices.
    """
    shapes = list(shapes)
    n = len(shapes)
    if n == 0:
        raise ValueError("shapes must contain at least one matrix")
    # Without this check the cost formula silently yields a meaningless number.
    for idx, (left, right) in enumerate(zip(shapes, shapes[1:])):
        if left[1] != right[0]:
            raise ValueError(
                f"matrix {idx} has shape {tuple(left)} but matrix {idx + 1} "
                f"has shape {tuple(right)}; inner dimensions must match"
            )

    # cost_cache[(i, j)] = (minimum cost, partitions) for the sub-chain i..j.
    # Only the diagonal needs seeding: a single matrix costs nothing.
    cost_cache = {(i, i): (0, []) for i in range(n)}

    # Solve sub-chains in order of increasing length (the upper diagonals of
    # the DP table), so every sub-chain referenced by a split is already done.
    for span in range(1, n):
        for i in range(n - span):
            j = i + span
            min_val = float("inf")
            min_k = None
            for k in range(i, j):
                val = (
                    cost_cache[(i, k)][0]
                    + cost_cache[(k + 1, j)][0]
                    + shapes[i][0] * shapes[k][1] * shapes[j][1]
                )
                # Strict "<" keeps the leftmost split point on ties.
                if val < min_val:
                    min_val = val
                    min_k = k
            # Pre-order: this split first, then the left and right sub-splits.
            partitions = (
                [(i, min_k, j)]
                + cost_cache[(i, min_k)][1]
                + cost_cache[(min_k + 1, j)][1]
            )
            cost_cache[(i, j)] = (min_val, partitions)
    return cost_cache[(0, n - 1)]


if __name__ == "__main__":
    matrix = [(6, 3), (3, 1), (1, 3), (3, 8)]
    # Expected: (90, [(0, 1, 3), (0, 0, 1), (2, 2, 3)])
    print(chain_matrix_product(matrix))

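Time complexity: O(n^3) for a chain of n matrices — there are O(n^2) sub-chains (i, j), and each one tries O(n) split points k.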

Result: (90, [(0,1,3),(0,0,1),(2,2,3)]) — the minimum cost is 90 scalar multiplications, followed by the chosen splits.

Each (i,k,j) tuple splits the product of matrices i..j into (i -> k) * (k+1 -> j).

Optimal parenthesization: [[(6,3)]*[(3,1)]]*[[(1,3)]*[(3,8)]]

matrix[0-0] * matrix[1-1] => matrix[0-1]
                                         \
                                          => matrix[0-1] * matrix[2-3] => matrix[0-3]
                                         /
matrix[2-2] * matrix[3-3] => matrix[2-3]

原创粉丝点击