caffe Accuracy.cpp
来源:互联网 发布:热血战歌龙心升级数据 编辑:程序博客网 时间:2024/05/18 02:41
代码比较简单。需要注意的一点是:在训练自己的数据时,label 应该从 0 开始,即取值于 {0, 1, ..., C-1}(C 为类别数),否则会违反 Forward_cpu 中的 DCHECK_GE/DCHECK_LT 检查。
#include <functional>#include <utility>#include <vector>#include "caffe/loss_layers.hpp"#include "caffe/util/math_functions.hpp"namespace caffe {template <typename Dtype>void AccuracyLayer<Dtype>::LayerSetUp( const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { top_k_ = this->layer_param_.accuracy_param().top_k(); has_ignore_label_ = this->layer_param_.accuracy_param().has_ignore_label(); if (has_ignore_label_) { ignore_label_ = this->layer_param_.accuracy_param().ignore_label(); }}template <typename Dtype>void AccuracyLayer<Dtype>::Reshape( const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { CHECK_LE(top_k_, bottom[0]->count() / bottom[1]->count()) << "top_k must be less than or equal to the number of classes."; label_axis_ = bottom[0]->CanonicalAxisIndex(this->layer_param_.accuracy_param().axis()); outer_num_ = bottom[0]->count(0, label_axis_); inner_num_ = bottom[0]->count(label_axis_ + 1); CHECK_EQ(outer_num_ * inner_num_, bottom[1]->count()) << "Number of labels must match number of predictions; " << "e.g., if label axis == 1 and prediction shape is (N, C, H, W), " << "label count (number of labels) must be N*H*W, " << "with integer values in {0, 1, ..., C-1}."; vector<int> top_shape(0); // Accuracy is a scalar; 0 axes. top[0]->Reshape(top_shape); if (top.size() > 1) { // Per-class accuracy is a vector; 1 axes. 
vector<int> top_shape_per_class(1); top_shape_per_class[0] = bottom[0]->shape(label_axis_); top[1]->Reshape(top_shape_per_class); nums_buffer_.Reshape(top_shape_per_class); }}template <typename Dtype>void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { Dtype accuracy = 0; const Dtype* bottom_data = bottom[0]->cpu_data(); const Dtype* bottom_label = bottom[1]->cpu_data(); const int dim = bottom[0]->count() / outer_num_; const int num_labels = bottom[0]->shape(label_axis_);//1000类就是1000 vector<Dtype> maxval(top_k_+1); vector<int> max_id(top_k_+1); if (top.size() > 1) { caffe_set(nums_buffer_.count(), Dtype(0), nums_buffer_.mutable_cpu_data()); caffe_set(top[1]->count(), Dtype(0), top[1]->mutable_cpu_data()); } int count = 0; for (int i = 0; i < outer_num_; ++i) { for (int j = 0; j < inner_num_; ++j) { const int label_value = static_cast<int>(bottom_label[i * inner_num_ + j]); if (has_ignore_label_ && label_value == ignore_label_) { continue; } if (top.size() > 1) ++nums_buffer_.mutable_cpu_data()[label_value]; DCHECK_GE(label_value, 0); DCHECK_LT(label_value, num_labels);//训练自己的数据,类别必须从0开始 // Top-k accuracy std::vector<std::pair<Dtype, int> > bottom_data_vector; for (int k = 0; k < num_labels; ++k) { bottom_data_vector.push_back(std::make_pair( bottom_data[i * dim + k * inner_num_ + j], k)); } std::partial_sort( bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_, bottom_data_vector.end(), std::greater<std::pair<Dtype, int> >());//排序 取top_k // check if true label is in top k predictions for (int k = 0; k < top_k_; k++) { if (bottom_data_vector[k].second == label_value) { ++accuracy; if (top.size() > 1) ++top[1]->mutable_cpu_data()[label_value]; break; } } ++count; } } // LOG(INFO) << "Accuracy: " << accuracy; top[0]->mutable_cpu_data()[0] = accuracy / count; if (top.size() > 1) { for (int i = 0; i < top[1]->count(); ++i) { top[1]->mutable_cpu_data()[i] = nums_buffer_.cpu_data()[i] == 0 ? 
0 : top[1]->cpu_data()[i] / nums_buffer_.cpu_data()[i]; } } // Accuracy layer should not be used as a loss function.}INSTANTIATE_CLASS(AccuracyLayer);REGISTER_LAYER_CLASS(Accuracy);} // namespace caffe
0 0
- caffe Accuracy.cpp
- caffe accuracy 学习
- caffe中的Accuracy
- caffe 绘制loss/ accuracy曲线
- caffe绘制loss,accuracy错误
- Caffe分类accuracy过低
- caffe绘制loss,accuracy曲线
- caffe画loss accuracy曲线
- Caffe Notes: Caffe.cpp
- caffe训练数据时,accuracy 一直是0
- caffe 绘制accuracy和loss曲线
- caffe中可视化Loss和accuracy
- caffe accuracy层以及blob的梳理
- Caffe 绘制训练过程loss,accuracy曲线
- caffe 训练增加日志,画accuracy曲线
- matlab 绘制caffe accuracy与loss曲线
- caffe绘制loss和accuracy曲线
- caffe loss、accuracy等数据可视化
- 第四周项目5-用递归方法求解(1)
- JAVA SE核心 学习day01
- HelloAndroidUpDate
- 东北衰败宣告了国企城市的破产
- GLUT文档
- caffe Accuracy.cpp
- Java Vector 构造函数与增长的探究
- Leetcode 57 Insert Interval
- HDU 2841 容斥原理
- 排序算法-及其Java代码实现
- 轮换相乘的小程序
- JSP开发模式
- 2016-03-28缓存
- iOS UITableView2