MXNet的训练入口:fit.py源码详解
来源:互联网 发布:淘宝开店运营流程图 编辑:程序博客网 时间:2024/06/05 16:38
fit.py是MXNet的fine-tune.py(参看博文:MXNet的fine-tune.py源码详解)中启动训练的入口,非常值得读一读源码。这个脚本是作者包装好的训练入口,最核心的还是Module类的fit方法(model.fit()就是Module类的对象在条用fit方法)。总的来讲,这个fit.py脚本包含训练的一些配置,导入模型,训练模型和保存模型这几步,接下来详细阐述。建议从最后的主函数fit()开始看起。
import mxnet as mximport loggingimport osimport time# 这个函数主要是和学习率相关,我们在启动训练的时候一般会添加这个参数:--lr,就是学习率,# 如果不设置的话就会采用fine-tune.py中的默认lr。lr_factor表示当你要改变lr的时候是以多大比率改变,# 举个例子,你原来lr是0.1,设置的lr_step_epochs是2,那么当你训练到epoch==2的时候,# 你的lr就会变成lr*lr_factor,这个lr_factor在fit.py脚本中默认设置为0.1。# 另外你的lr_step_epochs可以有多个值,比如(2,4,6),表示当epoch达到这些值的时候都要乘以lr_factor。def _get_lr_scheduler(args, kv): if 'lr_factor' not in args or args.lr_factor >= 1: return (args.lr, None) epoch_size = args.num_examples / args.batch_size if 'dist' in args.kv_store: epoch_size /= kv.num_workers begin_epoch = args.load_epoch if args.load_epoch else 0 step_epochs = [int(l) for l in args.lr_step_epochs.split(',')] lr = args.lr for s in step_epochs: if begin_epoch >= s: lr *= args.lr_factor if lr != args.lr: logging.info('Adjust learning rate to %e for epoch %d' %(lr, begin_epoch)) steps = [epoch_size * (x-begin_epoch) for x in step_epochs if x-begin_epoch > 0] return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))# 导入模型,其实本质还是和fine-tune.py中的导入模型一样采用model.py脚本中的load_checkpoint函数。# 首先判断你的load_epoch参数有没有设置,没设置的话运行时候会直接报错。def _load_model(args, rank=0): if 'load_epoch' not in args or args.load_epoch is None: return (None, None, None) assert args.model_prefix is not None model_prefix = args.model_prefix if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)): model_prefix += "-%d" % (rank) sym, arg_params, aux_params = mx.model.load_checkpoint( model_prefix, args.load_epoch) logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch) return (sym, arg_params, aux_params)#保存模型,需要你设置model_prefix这个参数。def _save_model(args, rank=0): if args.model_prefix is None: return None dst_dir = os.path.dirname(args.model_prefix) if not os.path.isdir(dst_dir): os.mkdir(dst_dir) return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % ( args.model_prefix, rank))#这部分主要是配置和训练相关的一些参数,里面的help字段是对该参数的解释,比较清晰了。def add_fit_args(parser): """ parser : argparse.ArgumentParser return a parser added with args required by fit """ train = parser.add_argument_group('Training', 'model training') train.add_argument('--network', type=str, help='the neural network to use') train.add_argument('--num-layers', type=int, help='number of layers in the neural network, required by some networks such as resnet') train.add_argument('--gpus', type=str, help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu') train.add_argument('--kv-store', type=str, default='device', help='key-value store type') train.add_argument('--num-epochs', type=int, default=100, help='max num of epochs') train.add_argument('--lr', type=float, default=0.1, help='initial learning rate') train.add_argument('--lr-factor', type=float, default=0.1, help='the ratio to reduce lr on each step') train.add_argument('--lr-step-epochs', type=str, help='the epochs to reduce the lr, e.g. 30,60') train.add_argument('--optimizer', type=str, default='sgd', help='the optimizer type') train.add_argument('--mom', type=float, default=0.9, help='momentum for sgd') train.add_argument('--wd', type=float, default=0.0001, help='weight decay for sgd') train.add_argument('--batch-size', type=int, default=128, help='the batch size') train.add_argument('--disp-batches', type=int, default=20, help='show progress for every n batches') train.add_argument('--model-prefix', type=str, help='model prefix') parser.add_argument('--monitor', dest='monitor', type=int, default=0, help='log network parameters every N iters if larger than 0') train.add_argument('--load-epoch', type=int, help='load the model on an epoch using the model-load-prefix') train.add_argument('--top-k', type=int, default=0, help='report the top-k accuracy. 0 means no report.') train.add_argument('--test-io', type=int, default=0, help='1 means test reading speed without training') return traindef fit(args, network, data_loader, **kwargs): """ train a model args : argparse returns network : the symbol definition of the nerual network data_loader : function that returns the train and val data iterators """ # kvstore# kvstore主要是解决你的梯度更新是在cpu进行还是gpu进行,这里主要调用kvstore.py脚本的create函数,# 路径是~/mxnet/python/mxnet/kvstore.py。这里的kv_store默认是‘device’,表示在GPU上计算梯度和更新权重,# 如果是要分布式训练,可以修改成’dist_device_sync‘。如果是在cpu上更新,那么kv_store就要设置为‘local’,# 不过一般不会这么做,详细可以参考kvstore.py这个脚本。 kv = mx.kvstore.create(args.kv_store) # logging#这个主要是日志细节信息,可以参考另一篇博文:[MXNet的fine-tune.py源码详解] head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s' logging.basicConfig(level=logging.DEBUG, format=head) logging.info('start with arguments %s', args) # data iterators# 数据导入主要在fine-tune.py脚本中调用data.py的get_rec_iter函数时候已经做好了,# 所以这一部分做的就很少,而且test_io默认是0,那个if语句也就进不去了。 (train, val) = data_loader(args, kv) if args.test_io: tic = time.time() for i, batch in enumerate(train): for j in batch.data: j.wait_to_read() if (i+1) % args.disp_batches == 0: logging.info('Batch [%d]\tSpeed: %.2f samples/sec' % ( i, args.disp_batches*args.batch_size/(time.time()-tic))) tic = time.time() return # load model#导入模型参数已经在fine-tune.py中做好了,这里就是简单地一个赋值。当然如果你之前导入模型没做好,这个还有个else可以导入。 if 'arg_params' in kwargs and 'aux_params' in kwargs: arg_params = kwargs['arg_params'] aux_params = kwargs['aux_params'] else: sym, arg_params, aux_params = _load_model(args, kv.rank) if sym is not None: assert sym.tojson() == network.tojson() # save model# 保存模型就是你训练结束后生成的.param文件,跳到前面的_save_model函数 checkpoint = _save_model(args, kv.rank) # devices for training#设置在cpu还是gpu上训练,因为默认的gpus参数没有值,所以你如果要在gpu上训练,需要在启动训练的时候加上类似--gpus 0这样的参数。 devs = mx.cpu() if args.gpus is None or args.gpus is '' else [ mx.gpu(int(i)) for i in args.gpus.split(',')] # learning rate#关于学习率的改变 lr, lr_scheduler = _get_lr_scheduler(args, kv) # create model# 这一步比较重要,通过mx.mod.Module函数生成model对象,注意这里的symbol就是网络结构,用fit函数的输入之一network赋值。# 如果你需要在训练的时候固定一些层的参数不更新,只更新部分层的参数,那么可以在生成这个model对象的时候加上# 类似fixed_param_names = [‘layer_name1’,‘layer_name2’]这样的参数,表示这两个参数不参与更新。 model = mx.mod.Module( context = devs, symbol = network ) lr_scheduler = lr_scheduler optimizer_params = { 'learning_rate': lr, 'momentum' : args.mom, 'wd' : args.wd, 'lr_scheduler': lr_scheduler, 'multi_precision': True}# monitor参数可以用于设置迭代多少次就显示下网络的参数情况 monitor = mx.mon.Monitor(args.monitor, pattern=".*") if args.monitor > 0 else None# 这一部分主要是参数的初始化,可以看到随机初始化主要采用gaussion,不过现在一般都采用fine-tune,# 即用别人的模型的参数来初始化你的模型,随机初始化用的并不多了。 if args.network == 'alexnet': # AlexNet will not converge using Xavier initializer = mx.init.Normal() else: initializer = mx.init.Xavier( rnd_type='gaussian', factor_type="in", magnitude=2) # initializer = mx.init.Xavier(factor_type="in", magnitude=2.34),# 评价标准,默认采用准确率,top_k是表示预测的类别概率最高的前k个包含真实概率即算预测正确,常见ImageNet比赛的top_5。 # evaluation metrices eval_metrics = ['accuracy'] if args.top_k > 0: eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=args.top_k)) # callbacks that run after each batch# 这个batch_end_callbacks简单讲就是显示训练到第几个epoch了,第几个batch了以及训练速度等信息。# 这里的dis_batches参数就是你在训练界面看到的多少batch后显示结果,默认是20。 batch_end_callbacks = [mx.callback.Speedometer(args.batch_size, args.disp_batches)] if 'batch_end_callback' in kwargs: cbs = kwargs['batch_end_callback'] batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs] # run# 最后是最重要的训练启动入口,这个model.fit()表示调用model对象的fit函数,# 这个model对象是前面create model时候通过mx.mod.Module生成的,详细可以查看~/mxnet/python/mxnet/model.py脚本,# 里面包含最重要的fit()函数,下一篇博客细讲这个函数。 model.fit(train, begin_epoch = args.load_epoch if args.load_epoch else 0, num_epoch = args.num_epochs, eval_data = val, eval_metric = eval_metrics, kvstore = kv, optimizer = args.optimizer, optimizer_params = optimizer_params, initializer = initializer, arg_params = arg_params, aux_params = aux_params, batch_end_callback = batch_end_callbacks, epoch_end_callback = checkpoint, allow_missing = True, monitor = monitor)
阅读全文
0 0
- MXNet的训练入口:fit.py源码详解
- MXNet的预训练:fine-tune.py源码详解
- MXNet的数据读取:data.py源码详解
- MXNet的训练基础脚本:base_module.py
- MXNet的训练实现脚本:module.py
- MXNet的数据预处理:mxnet.image.CreateAugmenter源码详解
- mxnet中im2rec.py的小问题
- Mxnet(3)-SSD训练自己的数据
- Mxnet(4)-fcn训练自己的数据
- mxnet im2re.py
- mxnet多层感知机训练MNIST数据集详解【转】
- MXNet的源码编译过程总结
- 【问题 解决】mxnet训练mnist数据集的Train_accuracy很小
- mxnet 使用自己的图片数据训练CNN模型
- MXNet:训练自己的数据并做预测
- Mxnet训练自己的数据集并测试
- mxnet利用下载好的mnist数据训练cnn
- wsgi.py的详解
- $(document).ready()笔记
- [乱搞]斐波那契数列与gcd之间一个有趣的定理
- Android图片处理框架之Fresco学习使用
- java filter过滤器的配置
- 面向对象和面向过程的区别
- MXNet的训练入口:fit.py源码详解
- Android 动画总结
- c++ 学习 内存四区
- PAT--1050. String Subtraction
- Spring Cloud构建微服务架构(七)消息总线
- 文章标题
- shell for用法
- 基于jdk的网络编程和使用Netty的比较
- Java反射机制应用实践