caffe_score.py

来源:互联网 发布:文明的起源知乎 编辑:程序博客网 时间:2024/05/16 07:26
from __future__ import division
from __future__ import print_function

import os
import sys
from datetime import datetime

import caffe
import numpy as np
from PIL import Image


def fast_hist(a, b, n):
    """Return the n-by-n confusion matrix between label arrays `a` and `b`.

    Args:
        a: 1-D array of reference labels. Entries outside ``[0, n)`` are
            masked out and do not contribute (presumably this is how
            "ignore" labels are dropped -- confirm against the dataset).
        b: 1-D array of predicted labels, same length as `a`. Must hold
            integers; values are assumed to lie in ``[0, n)``.
        n: Number of classes.

    Returns:
        An ``(n, n)`` integer array whose entry ``[i, j]`` counts the
        positions where ``a == i`` and ``b == j``.
    """
    valid = (a >= 0) & (a < n)
    # Encode each (a, b) pair as a single index i * n + j so that one
    # bincount call tallies every cell of the matrix.
    pair_index = n * a[valid].astype(int) + b[valid]
    return np.bincount(pair_index, minlength=n**2).reshape(n, n)


def compute_hist(net, save_dir, dataset, layer='score', gt='label'):
    """Run `net` once per dataset entry and accumulate a confusion matrix.

    Args:
        net: Network exposing ``forward()`` and a ``blobs`` mapping that
            contains `layer`, `gt` and a blob named ``'loss'``.
        save_dir: Directory in which to save each prediction as
            ``<idx>.png``; a falsy value disables saving. The directory is
            created if it does not exist yet.
        dataset: Sized iterable of string identifiers. The identifiers are
            only used to name saved images; the inputs themselves come
            from whatever ``net.forward()`` loads (presumably in the same
            order as `dataset` -- confirm against the data layer).
        layer: Name of the blob holding per-class scores.
        gt: Name of the blob holding the reference labels.

    Returns:
        A tuple ``(hist, mean_loss)``: the accumulated confusion matrix
        (rows follow `gt`, columns follow the prediction) and the loss
        averaged over ``len(dataset)``.

    Raises:
        ZeroDivisionError: If `dataset` is empty.
    """
    n_cl = net.blobs[layer].channels
    # Tolerate a pre-existing directory and create missing parents instead
    # of failing the way a bare os.mkdir would.
    if save_dir and not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    hist = np.zeros((n_cl, n_cl))
    loss = 0
    for idx in dataset:
        net.forward()
        # Predicted class per position: argmax over the first axis of the
        # first item in the score blob. Computed once and reused below.
        pred = net.blobs[layer].data[0].argmax(0)
        hist += fast_hist(net.blobs[gt].data[0, 0].flatten(),
                          pred.flatten(),
                          n_cl)
        if save_dir:
            im = Image.fromarray(pred.astype(np.uint8), mode='P')
            im.save(os.path.join(save_dir, idx + '.png'))
        # compute the loss as well
        loss += net.blobs['loss'].data.flat[0]
    return hist, loss / len(dataset)


def seg_tests(solver, save_format, dataset, layer='score', gt='label'):
    """Score the solver's first test net using the current training weights.

    Calls ``share_with`` on ``solver.test_nets[0]`` with ``solver.net`` and
    then delegates to `do_seg_tests`, passing ``solver.iter`` as the
    iteration number. Returns nothing.
    """
    print('>>>', datetime.now(), 'Begin seg tests')
    solver.test_nets[0].share_with(solver.net)
    do_seg_tests(solver.test_nets[0], solver.iter, save_format, dataset,
                 layer, gt)


def do_seg_tests(net, iter, save_format, dataset, layer='score', gt='label'):
    """Evaluate `net` on `dataset` and print segmentation metrics.

    Prints the mean loss, overall accuracy, mean per-class accuracy,
    mean IU and frequency-weighted IU ("fwavacc").

    Args:
        net: Network to evaluate; see `compute_hist`.
        iter: Iteration number, used in the printed lines and substituted
            into `save_format`. (The name shadows the builtin but is kept
            because it is part of the public signature.)
        save_format: ``str.format`` template for the output directory,
            formatted with `iter`; a falsy value disables image saving.
        dataset: Sized iterable of string identifiers; see `compute_hist`.
        layer: Name of the blob holding per-class scores.
        gt: Name of the blob holding the reference labels.

    Returns:
        The accumulated confusion matrix.
    """
    if save_format:
        save_format = save_format.format(iter)
    hist, loss = compute_hist(net, save_format, dataset, layer, gt)
    # mean loss
    print('>>>', datetime.now(), 'Iteration', iter, 'loss', loss)
    # Classes that never occur yield 0/0 = NaN below. That is intended
    # (nanmean skips them), so silence the warnings numpy would raise.
    with np.errstate(divide='ignore', invalid='ignore'):
        true_pos = np.diag(hist)
        # overall accuracy
        acc = true_pos.sum() / hist.sum()
        print('>>>', datetime.now(), 'Iteration', iter,
              'overall accuracy', acc)
        # per-class accuracy
        acc = true_pos / hist.sum(1)
        print('>>>', datetime.now(), 'Iteration', iter,
              'mean accuracy', np.nanmean(acc))
        # per-class IU
        iu = true_pos / (hist.sum(1) + hist.sum(0) - true_pos)
        print('>>>', datetime.now(), 'Iteration', iter,
              'mean IU', np.nanmean(iu))
        freq = hist.sum(1) / hist.sum()
    # Weight each class's IU by its label frequency, skipping classes
    # that do not occur.
    present = freq > 0
    print('>>>', datetime.now(), 'Iteration', iter, 'fwavacc',
          (freq[present] * iu[present]).sum())
    return hist
原创粉丝点击