网络流算法Dinic的Python实现

来源:互联网 发布:mysql字符集设置 编辑:程序博客网 时间:2024/04/27 14:45

在上一篇我们提到了网络流算法Push-relabel,那是90年代提出的算法,算是比较新的,而现在要说的Dinic算法则是由以色列人Dinitz在冷战时期,即60-70年代提出的算法变种而来的,其算法复杂度为O(mn^2)。

Dinic算法主要思想也是基于FF算法的,改进的地方也是减少寻找增广路径的迭代次数。此处Dinitz大师引用了一个非常聪明的数据结构,Layer Network,分层网络,该结构是由BFS tree启发得到的,它跟BFS tree的区别在于,BFS tree只保存到每一层的一条边,这样就导致了利用BFS tree一次只能发现一条增广路径,而分层网络保存了到每一层的所有边,但层内的边不保存。

介绍完数据结构,开始讲算法的步骤了,1)从网络的剩余图中利用BFS宽度优先遍历技术生成分层网络。2)在分层网络中不断调用DFS生成增广路径,直到s不可到达t,这一步体现了Dinic算法贪心的特性。3)max_flow+=这次生成的所有增广路径的flow,重新生成剩余图,转1)。

源代码如下:

采用递归实现BFS和DFS,效率不高。

__author__ = 'xanxus'nodeNum, edgeNum = 0, 0arcs = []class Arc(object):    def __init__(self):        self.src = -1        self.dst = -1        self.cap = -1class Layer(object):    def __init__(self):        self.nodeSet = set()        self.arcList = []s, t = -1, -1with open('demo.dimacs') as f:    for line in f.readlines():        line = line.strip()        if line.startswith('p'):            tokens = line.split(' ')            nodeNum = int(tokens[2])            edgeNum = tokens[3]        if line.startswith('n'):            tokens = line.split(' ')            if tokens[2] == 's':                s = int(tokens[1])            if tokens[2] == 't':                t = int(tokens[1])        if line.startswith('a'):            tokens = line.split(' ')            arc = Arc()            arc.src = int(tokens[1])            arc.dst = int(tokens[2])            arc.cap = int(tokens[3])            arcs.append(arc)nodes = [-1] * nodeNumfor i in range(s, t + 1):    nodes[i - s] = iadjacent_matrix = [[0 for i in range(nodeNum)] for j in range(nodeNum)]for arc in arcs:    adjacent_matrix[arc.src - s][arc.dst - s] = arc.capdef getLayerNetwork(current, ln, augment_set):    if t - s in ln[current].nodeSet:        return    for i in ln[current].nodeSet:        augment_set.add(i)        has_augment = False        for j in range(len(adjacent_matrix)):            if adjacent_matrix[i][j] != 0:                if len(ln) == current + 1:                    ln.append(Layer())                if j not in augment_set and j not in ln[current].nodeSet:                    has_augment = True                    ln[current + 1].nodeSet.add(j)                    arc = Arc()                    arc.src, arc.dst, arc.cap = i, j, adjacent_matrix[i][j]                    ln[current].arcList.append(arc)        if not has_augment and (i != t - s or i != 0):            augment_set.remove(i)            filter(lambda x: x == i, ln[current].nodeSet)            newArcList = []            for arc in ln[current - 1].arcList:                if arc.dst != i:                    newArcList.append(arc)            ln[current - 1].arcList = newArcList    if len(ln) == current + 1:        return    getLayerNetwork(current + 1, ln, augment_set)def get_path(layerNetwork, src, current, path):    for arc in layerNetwork[current].arcList:        if arc.src == src and arc.cap != 0:            path.append(arc)            get_path(layerNetwork, arc.dst, current + 1, path)            returndef find_blocking_flow(layerNetwork):    sum_flow = 0    while (True):        path = []        get_path(layerNetwork, 0, 0, path)        if path[-1].dst != t - s:            break        else:            bottleneck = min([arc.cap for arc in path])            for arc in path:                arc.cap -= bottleneck            sum_flow += bottleneck    return sum_flowmax_flow = 0while (True):    layerNetwork = []    firstLayer = Layer()    firstLayer.nodeSet.add(0)    layerNetwork.append(firstLayer)    augment_set = set()    augment_set.add(0)    getLayerNetwork(0, layerNetwork, augment_set)    if t - s not in layerNetwork[-1].nodeSet:        break    current_flow = find_blocking_flow(layerNetwork)    if current_flow == 0:        break    else:        max_flow += current_flow        # add the backward arcs        for layer in layerNetwork:            for arc in layer.arcList:                adjacent_matrix[arc.dst][arc.src] += adjacent_matrix[arc.src][arc.dst] - arc.cap                adjacent_matrix[arc.src][arc.dst] = arc.capfor arc in arcs:    print 'f %d %d %d' % (arc.src, arc.dst, arc.cap - adjacent_matrix[arc.src - s][arc.dst - s])


0 0
原创粉丝点击