深度学习Matlab工具箱代码注释——cnntrain.m
来源:互联网 发布:部落冲突黑油钻井数据 编辑:程序博客网 时间:2024/06/05 14:19
%%=========================================================================%函数名称:cnntrain()%输入参数:net,神经网络;x,训练数据矩阵;y,训练数据的标签矩阵;opts,神经网络的相关训练参数%输出参数:net,训练完成的卷积神经网络%算法流程:1)将样本打乱,随机选择进行训练;% 2)取出样本,通过cnnff2()函数计算当前网络权值和网络输入下网络的输出% 3)通过BP算法计算误差对网络权值的导数% 4)得到误差对权值的导数后,就通过权值更新方法去更新权值%注意事项:1)使用BP算法计算梯度%%=========================================================================function net = cnntrain(net, x, y, opts)m = size(x, 3); %m保存的是训练样本个数disp(['样本总个数=' num2str(m)]);numbatches = m / opts.batchsize; %numbatches表示每次迭代中所选取的训练样本数if rem(numbatches, 1) ~= 0 %如果numbatches不是整数,则程序发生错误 error('numbatches not integer');end%%=====================================================================%主要功能:CNN网络的迭代训练%实现步骤:1)通过randperm()函数将原来的样本顺序打乱,再挑出一些样本来进行训练% 2)取出样本,通过cnnff2()函数计算当前网络权值和网络输入下网络的输出% 3)通过BP算法计算误差对网络权值的导数% 4)得到误差对权值的导数后,就通过权值更新方法去更新权值%注意事项:1)P = randperm(N),返回[1, N]之间所有整数的一个随机的序列,相当于把原来的样本排列打乱,% 再挑出一些样本来训练% 2)采用累积误差的计算方式来评估当前网络性能,即当前误差 = 以前误差 * 0.99 + 本次误差 * 0.01% 使得网络尽可能收敛到全局最优%%=====================================================================net.rL = []; %代价函数值,也就是误差值for i = 1 : opts.numepochs %对于每次迭代 disp(['epoch ' num2str(i) '/' num2str(opts.numepochs)]); tic; %使用tic和toc来统计程序运行时间 %%%%%%%%%%%%%%%%%%%%取出打乱顺序后的batchsize个样本和对应的标签 %%%%%%%%%%%%%%%%%%%% kk = randperm(m); for l = 1 : numbatches batch_x = x(:, :, kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize)); batch_y = y(:, kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize)); %%%%%%%%%%%%%%%%%%%%在当前的网络权值和网络输入下计算网络的输出(特征向量)%%%%%%%%%%%%%%%%%%%% net = cnnff(net, batch_x); %卷积神经网络的前馈运算 %%%%%%%%%%%%%%%%%%%%通过对应的样本标签用bp算法来得到误差对网络权值的导数%%%%%%%%%%%%%%%%%%%% net = cnnbp(net, batch_y); %卷积神经网络的BP算法 %%%%%%%%%%%%%%%%%%%%通过权值更新方法去更新权值%%%%%%%%%%%%%%%%%%%% net = cnnapplygrads(net, opts); if isempty(net.rL) net.rL(1) = net.L; %代价函数值,也就是均方误差值 ,在cnnbp.m中计算初始值 net.L = 1/2* sum(net.e(:) .^ 2) / size(net.e, 2); end net.rL(end + 1) = 0.99 * net.rL(end) + 0.01 * net.L; %采用累积的方式计算累积误差 end toc;endend
3 0
- 深度学习Matlab工具箱代码注释——cnntrain.m
- 深度学习Matlab工具箱代码注释——cnntrain.m
- 深度学习Matlab工具箱代码注释——cnntrain.m
- 深度学习Matlab工具箱代码注释——MnistTest.m
- 深度学习Matlab工具箱代码注释——cnnsetup.m
- 深度学习Matlab工具箱代码注释——cnnff.m
- 深度学习Matlab工具箱代码注释——cnnbp.m
- 深度学习Matlab工具箱代码注释——cnnapplygrads.m
- 深度学习Matlab工具箱代码注释——cnnsetup.m
- 深度学习Matlab工具箱代码注释——cnnsetup.m
- 深度学习Matlab工具箱代码注释——cnnff.m
- 深度学习Matlab工具箱代码注释——cnnbp.m
- 深度学习Matlab工具箱代码注释——cnnapplygrads.m
- 深度学习Matlab工具箱代码注释——MnistTest.m
- 深度学习Matlab工具箱代码注释
- Matlab深度学习笔记——深度学习工具箱说明
- 深度学习Matlab工具箱代码详解
- 深度学习Matlab工具箱代码详解
- 二叉树的中序遍历
- 在恩典中生活
- Oracle 11g安装过程出现“未找到文件”
- Android 位置服务——BaiduMap的使用
- java中double保留小数点
- 深度学习Matlab工具箱代码注释——cnntrain.m
- HDU 1248 寒冰王座
- 关于形如--error LNK2005: xxx 已经在 msvcrtd.lib ( MSVCR90D.dll ) 中定义--的问题分析解决
- SONY 系列手机 Android 5.1 系统 Root 方法
- js入门(三)——document对象
- maven学习总结(七)——eclipse中使用Maven创建Web项目
- the variables of python
- hadoop2.7.1单机版安装部署
- iOS详解 GCD 串行队列并行队列