caffe accuracy 学习
来源:互联网 发布:国内旅游消费总额数据 编辑:程序博客网 时间:2024/05/21 22:51
首先我们先看一下accuracy (在caffe.proto里面)的类定义
message AccuracyParameter {
// When computing accuracy, count as correct by comparing the true label to
// the top k scoring classes. By default, only compare to the top scoring
// class (i.e. argmax).
optional uint32 top_k = 1 [default = 1];//默认只取前得分最高的值位置作为标签,如果大于1,就选择值最高的top_k个,如果标签值在这top_k中,就认为识别正确。
// The "label" axis of the prediction blob, whose argmax corresponds to the
// predicted label -- may be negative to index from the end (e.g., -1 for the
// last axis). For example, if axis == 1 and the predictions are
// (N x C x H x W), the label blob is expected to contain N*H*W ground truth
// labels with integer values in {0, 1, ..., C-1}.
optional int32 axis = 2 [default = 1];
// If specified, ignore instances with the given label.
optional int32 ignore_label = 3;//忽略的标签,出现这类标签就不进行统计。
}
accuracy_layer.cpp
template <typename Dtype>
void AccuracyLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
top_k_ = this->layer_param_.accuracy_param().top_k();
has_ignore_label_ =
this->layer_param_.accuracy_param().has_ignore_label();
if (has_ignore_label_) {
ignore_label_ = this->layer_param_.accuracy_param().ignore_label();
}
}
template <typename Dtype>
void AccuracyLayer<Dtype>::Reshape(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
//bottom[0]是前一层的输入 bottom[1]是label的输入
CHECK_LE(top_k_, bottom[0]->count() / bottom[1]->count())
<< "top_k must be less than or equal to the number of classes.";
label_axis_ =
bottom[0]->CanonicalAxisIndex(this->layer_param_.accuracy_param().axis());
/*
inline int count(int start_axis, int end_axis) const {
CHECK_LE(start_axis, end_axis);
CHECK_GE(start_axis, 0);
CHECK_GE(end_axis, 0);
CHECK_LE(start_axis, num_axes());
CHECK_LE(end_axis, num_axes());
int count = 1;
for (int i = start_axis; i < end_axis; ++i) {
count *= shape(i);
}
return count;
}
inline int count(int start_axis) const {
return count(start_axis, num_axes());
}
*/
outer_num_ = bottom[0]->count(0, label_axis_);//如果默认的情况就是label_axis_ = 1,outer_num_ = N
inner_num_ = bottom[0]->count(label_axis_ + 1);//如果默认的情况就是label_axis_ = 1,inner_num_ = W*H
CHECK_EQ(outer_num_ * inner_num_, bottom[1]->count())
<< "Number of labels must match number of predictions; "
<< "e.g., if label axis == 1 and prediction shape is (N, C, H, W), "
<< "label count (number of labels) must be N*H*W, "
<< "with integer values in {0, 1, ..., C-1}.";
vector<int> top_shape(0); // Accuracy is a scalar; 0 axes.??为什么这样设置,不知?
top[0]->Reshape(top_shape);
}
template <typename Dtype>
void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
Dtype accuracy = 0;
const Dtype* bottom_data = bottom[0]->cpu_data();
const Dtype* bottom_label = bottom[1]->cpu_data();
const int dim = bottom[0]->count() / outer_num_;
const int num_labels = bottom[0]->shape(label_axis_);
vector<Dtype> maxval(top_k_+1);
vector<int> max_id(top_k_+1);
int count = 0;
for (int i = 0; i < outer_num_; ++i) {
for (int j = 0; j < inner_num_; ++j) {
const int label_value =
static_cast<int>(bottom_label[i * inner_num_ + j]);
if (has_ignore_label_ && label_value == ignore_label_) {
continue;
}
DCHECK_GE(label_value, 0);
DCHECK_LT(label_value, num_labels);
// Top-k accuracy
std::vector<std::pair<Dtype, int> > bottom_data_vector;
//将前一层输入进来的信息进行pair,pair为值和其位置k。
for (int k = 0; k < num_labels; ++k) {
bottom_data_vector.push_back(std::make_pair(
bottom_data[i * dim + k * inner_num_ + j], k));
}
//排序
std::partial_sort(
bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_,
bottom_data_vector.end(), std::greater<std::pair<Dtype, int> >());
// check if true label is in top k predictions
for (int k = 0; k < top_k_; k++) {
if (bottom_data_vector[k].second == label_value) {
++accuracy;
break;
}
}
++count;
}
}
// LOG(INFO) << "Accuracy: " << accuracy;
top[0]->mutable_cpu_data()[0] = accuracy / count;
// Accuracy layer should not be used as a loss function.
}
- caffe accuracy 学习
- caffe学习 绘制loss和accuracy
- caffe学习小问题(1):caffe中的Accuracy
- caffe学习小问题(1):caffe中的Accuracy
- caffe学习小问题(1):caffe中的Accuracy
- Caffe学习:使用pycaffe绘制loss、accuracy曲线
- Caffe学习系列(19): 绘制loss和accuracy曲线
- Caffe学习系列(19): 绘制loss和accuracy曲线
- Caffe学习系列: 绘制loss和accuracy曲线
- Caffe学习系列(19): 绘制loss和accuracy曲线
- caffe 学习笔记之ubuntu下绘制loss&accuracy 曲线
- Caffe学习:使用pycaffe绘制loss、accuracy曲线
- Caffe学习系列(19): 绘制loss和accuracy曲线
- caffe Accuracy.cpp
- caffe中的Accuracy
- Caffe学习:绘制loss和accuracy曲线(使用caffe工具包)
- Caffe学习:绘制loss和accuracy曲线(使用caffe的python接口)
- caffe 绘制loss/ accuracy曲线
- poj 1061
- UIButton修改文字大小问题
- 我的第一个博客
- 体育IP价值大爆发 本土赛事IP蕴含着巨大发展潜力
- Java线程(九):Condition-线程通信更高效的方式
- caffe accuracy 学习
- 奋斗吧,程序员——第四章 人生若只如初见,何事秋风悲画扇
- Wiggle Sort
- Linux——作业1
- PSR-1 基本代码规范
- Mybatis整合Spring(这篇写的很清楚所以转载了)
- C语言函数校对 符号函数sgn()
- Service详解
- Maven之Spring_boot创建表结构