caffe_score.py
来源:互联网 发布:文明的起源知乎 编辑:程序博客网 时间:2024/05/16 07:26
# caffe_score.py -- segmentation scoring helpers for Caffe nets.
#
# Runs a test net over a dataset, accumulates a confusion matrix between the
# ground-truth label blob and the argmax of the score blob, and prints loss,
# overall accuracy, mean per-class accuracy, mean IU and frequency-weighted
# IU. Written to run under both Python 2 and Python 3.
from __future__ import division, print_function

import os
import sys
from datetime import datetime

import caffe
import numpy as np
from PIL import Image


def fast_hist(a, b, n):
    """Return the n x n confusion matrix of labels ``a`` vs. predictions ``b``.

    Entry ``[i, j]`` counts the positions where ``a == i`` and ``b == j``.
    Positions whose label lies outside ``[0, n)`` are skipped, so an ignore
    label (any value >= n or < 0) drops out of the counts.

    Args:
        a: flat array of ground-truth labels (int- or float-valued).
        b: flat array of predicted class indices, same length as ``a``;
            values are expected to lie in ``[0, n)``.
        n: number of classes.

    Returns:
        An ``(n, n)`` array of integer counts.
    """
    k = (a >= 0) & (a < n)
    # Encode each (label, prediction) pair as one index so a single bincount
    # builds the whole matrix. Both operands are cast because np.bincount
    # rejects float input (Caffe blobs hold float data).
    return np.bincount(n * a[k].astype(int) + b[k].astype(int),
                       minlength=n ** 2).reshape(n, n)


def compute_hist(net, save_dir, dataset, layer='score', gt='label'):
    """Run ``net`` once per item of ``dataset`` and accumulate a confusion matrix.

    Args:
        net: Caffe net. Only batch element 0 of each forward pass is scored.
            NOTE(review): assumes the net's data layer yields images in the
            same order as ``dataset`` -- confirm against the data layer.
        save_dir: if truthy, directory where each predicted label map is
            written as ``<idx>.png`` (palette mode); it is created, including
            missing parents, when it does not exist.
        dataset: iterable of image ids; one forward pass is run per id.
        layer: name of the score blob; its channel axis holds the classes.
        gt: name of the label blob; channel 0 of batch element 0 is used.

    Returns:
        ``(hist, mean_loss)``: the ``(n_cl, n_cl)`` confusion matrix and the
        mean of blob ``'loss'`` over the passes (NaN if ``dataset`` is empty).
    """
    n_cl = net.blobs[layer].channels
    if save_dir and not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    hist = np.zeros((n_cl, n_cl))
    loss = 0
    n_iter = 0
    for idx in dataset:
        net.forward()
        # Class prediction per pixel: argmax over the channel axis. Computed
        # once and reused for both scoring and saving.
        pred = net.blobs[layer].data[0].argmax(0)
        hist += fast_hist(net.blobs[gt].data[0, 0].flatten(),
                          pred.flatten(),
                          n_cl)
        if save_dir:
            im = Image.fromarray(pred.astype(np.uint8), mode='P')
            im.save(os.path.join(save_dir, '{}.png'.format(idx)))
        # compute the loss as well
        loss += net.blobs['loss'].data.flat[0]
        n_iter += 1
    if n_iter == 0:
        # Nothing was evaluated; avoid ZeroDivisionError and report NaN.
        return hist, float('nan')
    return hist, loss / n_iter


def seg_tests(solver, save_format, dataset, layer='score', gt='label'):
    """Score the solver's current weights with its first test net.

    Calls ``share_with(solver.net)`` on ``solver.test_nets[0]`` (Caffe's
    Net.share_with, so the test net uses the training net's weights), then
    delegates to ``do_seg_tests`` at iteration ``solver.iter``.

    Returns:
        The confusion matrix computed by ``do_seg_tests``.
    """
    print('>>>', datetime.now(), 'Begin seg tests')
    test_net = solver.test_nets[0]
    test_net.share_with(solver.net)
    return do_seg_tests(test_net, solver.iter, save_format, dataset,
                        layer, gt)


def _log_metric(iter, name, value):
    """Print one timestamped '>>> <time> Iteration <iter> <name> <value>' line."""
    print('>>>', datetime.now(), 'Iteration', iter, name, value)


def do_seg_tests(net, iter, save_format, dataset, layer='score', gt='label'):
    """Evaluate ``net`` on ``dataset`` and print segmentation metrics.

    Printed metrics: mean loss, overall pixel accuracy, mean per-class
    accuracy, mean IU (intersection over union) and frequency-weighted IU.

    Args:
        net: Caffe test net (see ``compute_hist``).
        iter: training iteration, used in the log lines and to fill
            ``save_format``. (The name shadows the builtin but is kept for
            callers passing it by keyword.)
        save_format: ``str.format`` template for the prediction directory,
            filled with ``iter``; falsy to skip saving predictions.
        dataset: iterable of image ids (see ``compute_hist``).
        layer: name of the score blob.
        gt: name of the label blob.

    Returns:
        The ``(n_cl, n_cl)`` confusion matrix.
    """
    save_dir = save_format.format(iter) if save_format else save_format
    hist, loss = compute_hist(net, save_dir, dataset, layer, gt)
    true_pos = np.diag(hist)
    label_count = hist.sum(1)
    # A class absent from both labels and predictions yields 0/0 = NaN. That
    # is intended (nanmean skips it), so silence numpy's RuntimeWarnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        overall_acc = true_pos.sum() / hist.sum()
        class_acc = true_pos / label_count
        iu = true_pos / (label_count + hist.sum(0) - true_pos)
        freq = label_count / hist.sum()
    # mean loss
    _log_metric(iter, 'loss', loss)
    # overall accuracy
    _log_metric(iter, 'overall accuracy', overall_acc)
    # per-class accuracy
    _log_metric(iter, 'mean accuracy', np.nanmean(class_acc))
    # per-class IU
    _log_metric(iter, 'mean IU', np.nanmean(iu))
    # frequency-weighted IU over classes that occur in the labels
    present = freq > 0
    _log_metric(iter, 'fwavacc', (freq[present] * iu[present]).sum())
    return hist
阅读全文
0 0
- caffe_score.py
- py
- py
- py
- py
- py
- py
- py
- Py
- bin2hex.py && hex2bin.py
- web.py (url.py)
- [py]py存放家具
- Xctf之调皮的py-py-py
- dir_size.py
- sendEmail.py
- Html.py
- web.py
- mail.py
- Tensorflow add op
- 使用 js 美化 json
- 圆角边框
- Java 访问PI 数据库:(1)安装必要软件
- 判断php变量是否为空/已定义
- caffe_score.py
- 逻辑回归模型(Logistic Regression, LR)基础
- BZOJ 1856 [Scoi2010]字符串 组合数学
- Thrift之TProtocol类体系原理及源码详细解析之紧凑协议类TCompactProtocolT
- IOS 11 下title 偏移问题,有人遇到过吗?
- 安装禅道参考一
- springmvc fastjson定制化输出
- Kubernetes 在知乎的应用
- Mac关闭WindowController,点击appicon打开