cxxnet学习笔记5
来源:互联网 发布:psv重构数据库 编辑:程序博客网 时间:2024/06/06 00:35
updater是参数更新模块,它定义了权值更新规则,其中AsyncUpdater是一个特殊的updater,用于处理异步通信和更新。该模块使用mshadow-ps完成异步通信。
updater模块包含的源码文件如下所示:
updater.h
param.h
adam_updater-inl.hpp
async_updater-inl.hpp
nag_updater-inl.hpp
sgd_updater-inl.hpp
updater_impl-inl.hpp
updater_impl.cpp
updater_impl.cu
updater.h包含其外部接口,该模块定义了两个类与几个函数,如下所示:
类IUpdater参数更新模型的接口,它定义了一系列关于参数更新过程的操作,如下所示:
注:ILayer并不负责参数更新,IUpdater采用由ILayer和权值累加的梯度值来更新权值。
class IUpdater {
public:
typedef typename layer::ILayer<xpu>::IVisitor IVisitor;
//复用layer的vistor类型,可以用来获得updater的权值。
virtual ~IUpdater(void) {}
virtual void SetStream(mshadow::Stream<xpu> *stream) = 0;
//设置内部计算流为流,参数中的stream为所用的流。
virtual voidInit(void) = 0;
//初始化,如果有响应则输出updater信息。
virtual void ApplyVisitor(IVisitor *pvisitor) = 0;
//应用visitor到updater,用于获取updater内容。
virtual void StartRound(int round) = 0;
//通知updater我们开始了在数据上的新一轮迭代,参数round表示第几轮迭代
virtual void Update(long epoch) = 0;
//更新参数。参数epoch为当前epoch值,epoch是当每一轮训练结束后mini-batches传过来的数。
virtual void Update(long epoch, mshadow::Tensor<xpu, 2> grad) = 0;
//更新参数,从外部提供梯度值。参数epoch属性不变,参数grad是为了简化接口的指向梯度值的指针,在传递梯度值之前FlatTo2D应该被调用。
virtual voidSetParam(const char *name, const char *val) = 0;
//用来设定updater参数
};
类IAsyncUpdater继承IUpdater,BeforeBackprop和AfterBackprop是异步调用函数,用户需要调用UpdateWait来等待更新过程完成。下面展示IAsyncUpdater结构:
template<typename xpu>
class IAsyncUpdater :public IUpdater<xpu> {
public:
virtual void BeforeBackprop(const std::vector<layer::Node<xpu>*> &nodes_in,const std::vector<layer::Node<xpu>*> &nodes_out) = 0;
//在BP前调用此函数,当updater想要通过其自身恢复梯度值时使用此函数,无需ILayer计算此过程。
virtual void AfterBackprop(bool do_update, long epoch) = 0;
//在BP后调用此函数,参数do_update决定是否在这个迭代器执行更新,如果更新,参数epoch为更新所用epoch值。
virtual void BeforeAllForward(void) = 0;
//当所有layer调用forwardprop时,该函数会被调用。
virtual voidUpdateWait(void) = 0;
//阻塞直至更新完成,如果没有更新或更新早已完成,该函数会直接返回。
//禁用Update功能
virtual voidUpdate(long epoch) {
utils::Error("IAsyncUpdater.Update call AfterBackprop instead");
}
virtual void Update(long epoch, mshadow::Tensor<xpu, 2> grad) {
utils::Error("IAsyncUpdater.Update call AfterBackprop instead");
}
};
根据指定类型创建updater,参数type指定updater类型,p_rnd为产生的随机数,weight是要更新的权值,wgrad为张量的梯度值,tag是weight的类型。
template<typename xpu>
IUpdater<xpu>* CreateUpdater(constchar *type,
mshadow::Random<xpu> *p_rnd,
mshadow::Tensor<xpu, 2> weight,
mshadow::Tensor<xpu, 2> wgrad,
const char *tag);
为指定layer创建多个updater,将它们推回out_updaters。参数layer_index为layer索引,device_id为异步updater所在设备id,param_server为异步updater可用的参数服务器,type指明updater类型,p_rnd指针为产生的随机数,layer_type为layer的类型,p_layer指针指向数据流出的layer对象,out_updaters容器存储输出,如果它里面已包含成员,就在容器尾部加入新的updater。
template<typename xpu>
void CreateAsyncUpdaters(int layer_index,
int device_id,
mshadow::ps::ISharedModel<xpu, real_t> *param_server,
const char *type,
mshadow::Random<xpu> *p_rnd,
layer::LayerType layer_type,
layer::ILayer<xpu> *p_layer,
std::vector<IAsyncUpdater<xpu>*> *out_updaters);
//kDataKeyStep是一常数,用于参数服务器的编码键索引。计算方式如下:
* data_key = layer_index * kDataKeyStep
* key(layer[i].bias) == i * kDataKeyStep + 1
* key(layer[i].bias) == i * kDataKeyStep + 1
static const intkDataKeyStep = 4;
//将layer索引和weight标签编码为唯一键,参数layer_index为layer索引,参数tag为weight类型
inline int EncodeDataKey(int layer_index, constchar *tag) {…}
//由参数key解码tag。
inline const char *DecodeTag(int key) {…}
updater.h的结构就如上所示。
param.h单元完成常见的参数更新操作,支持复杂的学习速率调度。只包含一个结构体UpdaterParam的定义,它表示每一个layer的潜在参数。
adam_updater-inl.hpp单元以momentum实现SGD,它包含一个模版类AdmUpdater的定义,继承于IUpdater ,该类实现所有相关操作。
async_updater-inl.hpp单元使用SGD实现异步更新,定义了一个类AsyncUpdater继承于IAsyncUpdater。
nag_updater-inl.hpp单元以momentum实现NAG,它也只定义了一个类NAGUpdater,其继承于IUpdater。
sgd_updater-inl.hpp单元以momentum实现SGD,它定义了一个结构体-clip(用于梯度划分和nan监测)和一个类SGDUpdater(继承于IUpdater)
updater_impl-inl.hpp单元一同编译所有updater的实现。该单元包括两个创建函数,分别是:CreateUpdater_和CreateAsyncUpdater_,这俩个函数根据指定类型创建指定updater。
它还定义了一个结构体CreateAsyncUpdaterVisitor,继承于IUpdater。
updater_impl.cpp实现了CreateUpdater函数和CreateAsyncUpdaters函数的CPU版本,updater_impl.cu实现了其GPU版本。
以上是该模块的讲解,后续会继续补充。该模块主要是参数更新模型,下一步我们会开始最后一个模块layer的学习。
- cxxnet学习笔记5
- cxxnet学习笔记1
- cxxnet学习笔记2
- cxxnet学习笔记3
- cxxnet学习笔记4
- cxxnet学习笔记6
- cxxnet学习笔记7
- cxxnet学习笔记78
- cxxnet学习笔记9
- 分布式机器学习框架:CXXNet
- CXXNET 安装教程
- linux下面安装cxxnet.
- Deep Learning Framework CXXNET Compilation
- cxxnet中multi-machine例子编译流程
- Windows8.1(64位)下用vs2013编译cxxnet
- JCA1.5学习笔记
- C++学习笔记(5)
- AD学习笔记5
- 多选解决方案(1)
- 第十六周 程序阅读一(1)
- 1025. 反转链表 (25)
- C++11 标准新特性: 右值引用与转移语义
- 黑马程序员——String类
- cxxnet学习笔记5
- [深入浅出Cocoa]iOS网络编程之CFNetwork
- linux下redis的安装
- 有关Windows API中wchar_t类型的函数
- 黑马程序员-----网络编程
- cocos2dx3.2 版本windows7+VS2013下环境搭建
- ElasticSearch:堆大小与swap设置
- Socket用法详解<1>
- 【Unity】方块滚动代码