(Caffe,LeNet)权值更新(七)
来源:互联网 发布:社交网络 肖恩帕克 编辑:程序博客网 时间:2024/05/19 07:08
本文地址:http://blog.csdn.net/mounty_fsc/article/details/51588773
在Solver::ApplyUpdate()函数中,根据反向传播阶段计算的loss关于网络权值的偏导,使用配置的学习策略,更新网络权值从而完成本轮学习。
1 模型优化
1.1 损失函数
损失函数
在caffe中,可以分为三个阶段:
- 前向计算阶段,这个阶段计算
fW - 反向传播阶段,这个阶段计算
∇fW - 权值更新阶段,这个阶段通过
∇fW,∇r(W) 等计算ΔW 从而更新W
1.2 随机梯度下降
在lenet中,solver的类型为SGD(Stochastic gradient descent)
SGD通过以下公式对权值进行更新:
其中,
2 代码分析
2.1 ApplyUpdate
void SGDSolver<Dtype>::ApplyUpdate() { // 获取该轮迭代的学习率(learning rate) Dtype rate = GetLearningRate(); // 对每一层网络的权值进行更新 // 在lenet中,只有`conv1`,`conv2`,`ip1`,`ip2`四层有参数 // 每层分别有参数与偏置参数两项参数 // 因而`learnable_params_`的size为8. for (int param_id = 0; param_id < this->net_->learnable_params().size(); ++param_id) { // 归一化,iter_size为1不需要,因而lenet不需要。 // 此处的归一化内容很简单,仅仅是iter_size大于1时值再除以iter_size Normalize(param_id); // 正则化 Regularize(param_id); // 计算更新值\delta w ComputeUpdateValue(param_id, rate); } // 更新权值 this->net_->Update();}
说明:
lenet中学习参数设置可从
lenet_solver.prototxt
中查到# The base learning rate, momentum and the weight decay of the network.base_lr: 0.01momentum: 0.9weight_decay: 0.0005# The learning rate policylr_policy: "inv"gamma: 0.0001power: 0.75
获取学习率函数ApplyUpdate代码此处不给出,查看注释(以及caffe.proto)可知有如下学习率获取策略。在Lenet中采用的是
inv
的策略,是一种没一轮迭代学习率都改变的策略。// The learning rate decay policy. The currently implemented learning rate // policies are as follows: // - fixed: always return base_lr. // - step: return base_lr * gamma ^ (floor(iter / step)) // - exp: return base_lr * gamma ^ iter // - inv: return base_lr * (1 + gamma * iter) ^ (- power) // - multistep: similar to step but it allows non uniform steps defined by // stepvalue // - poly: the effective learning rate follows a polynomial decay, to be // zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) // - sigmoid: the effective learning rate follows a sigmod decay // return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) // // where base_lr, max_iter, gamma, step, stepvalue and power are defined // in the solver parameter protocol buffer, and iter is the current iteration.
2.2 Regularize
该函数实际执行以下公式
代码如下:
void SGDSolver<Dtype>::Regularize(int param_id) { const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); const vector<float>& net_params_weight_decay = this->net_->params_weight_decay(); Dtype weight_decay = this->param_.weight_decay(); string regularization_type = this->param_.regularization_type(); // local_decay = 0.0005 in lenet Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; ... if (regularization_type == "L2") { // axpy means ax_plus_y. i.e., y = a*x + y caffe_axpy(net_params[param_id]->count(), local_decay, net_params[param_id]->cpu_data(), net_params[param_id]->mutable_cpu_diff()); } ...}
2.3 ComputeUpdateValue
该函数实际执行以下公式
代码如下:
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) { const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); const vector<float>& net_params_lr = this->net_->params_lr(); // momentum = 0.9 in lenet Dtype momentum = this->param_.momentum(); // local_rate = lr_mult * global_rate // lr_mult为该层学习率乘子,在lenet_train_test.prototxt中设置 Dtype local_rate = rate * net_params_lr[param_id]; // Compute the update to history, then copy it to the parameter diff. ... // axpby means ax_plus_by. i.e., y = ax + by // 计算新的权值更新变化值 \delta w,结果保存在历史权值变化中 caffe_cpu_axpby(net_params[param_id]->count(), local_rate, net_params[param_id]->cpu_diff(), momentum, history_[param_id]->mutable_cpu_data()); // 从历史权值变化中把变化值 \delta w 保存到历史权值中diff中 caffe_copy(net_params[param_id]->count(), history_[param_id]->cpu_data(), net_params[param_id]->mutable_cpu_diff()); ... }
2.4 net_->Update
实际执行以下公式:
caffe_axpy<Dtype>(count_, Dtype(-1), static_cast<const Dtype*>(diff_->cpu_data()), static_cast<Dtype*>(data_->mutable_cpu_data()));
参考文献:
[1]. http://caffe.berkeleyvision.org/tutorial/solver.html
- (Caffe,LeNet)权值更新(七)
- (Caffe,LeNet)反向传播
- (Caffe,LeNet)反向传播(六)
- (Caffe,LeNet)反向传播(六)
- caffe学习(9)LeNet在Caffe上的使用
- lenet and caffe-lenet
- (Caffe,LeNet)IDE单步调试(一)
- (Caffe,LeNet)网络训练流程(二)
- (Caffe,LeNet)初始化训练网络(三)
- (Caffe,LeNet)初始化测试网络(四)
- (Caffe,LeNet)前向计算(五)
- (Caffe,LeNet)IDE单步调试(一)
- (Caffe,LeNet)网络训练流程(二)
- python-caffe接口学习(Solving in Python with LeNet)
- windows+VS2013+CPU(only)安装caffe及训练lenet
- caffe学习笔记(2)【Training LeNet on MNIST with Caffe use CPU】
- 奔跑吧Caffe(在MNIST手写体数字集上用Caffe框架训练LeNet模型)
- 【caffe】Caffe的Python接口-官方教程-01-learning-Lenet-详细说明(含代码)
- leetcode-Remove Linked List Elements-203
- Linux下的TCP/IP编程----多播和广播的实现
- 关于项目进度的变更
- response.setHeader各种用法
- C语言之main函数
- (Caffe,LeNet)权值更新(七)
- jQuery删除和清空节点
- Android SQLite的基本操作
- POJ-1837 Balance
- C 位操作
- C#字节数组操作
- Java#Servlet规范#HTTP Protocol Parameters
- 进程处理例子
- HIVE入门