opennmt聊天模型训练

来源:互联网 发布:手机如何开网店淘宝店 编辑:程序博客网 时间:2024/06/05 14:34

1)数据处理:

th preprocess.lua -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/demo

训练数据:src-train.txt为输入语句,tgt-train.txt为对于的回答,两者一一对应;验证数据为src-val.txt,tgt-val.txt.输入数据每个字之间有空格隔开,读取的时候以空格为间隔读取每个字.

处理时,会分别对src-train.txt处理,提取得到字典src.dict,对tgt-train.txt,提取字典tgt.dict.同时由所得的字典将对于的语料转换成为数字索引,最后将处理后的数据以及字典保存在demo-train.t7中.

同时,输入还可以包含有特征,如词性或者其他特征.

2)模型训练:

th train.lua -data data/demo-train.t7 -save_model model-demo  -gpuid 1

加载处理后的数据,进行模型训练,学习率,迭代次数等参数可以设定.

模型保存为model-demo_final.t7;

3)测试

th translate.lua -model model_final.t7 -src data/src-test.txt -output pred.txt

加载训练好的模型model-demo_final.t7,输入语句为src-test.txt,对于的回答保存在pred.txt;

从命令行输入为:

th translate_stdin.lua -model model_final.t7 -src data/src-test.txt -output pred.txt

改进:

对preprocess.lua进行改进,由于输入输出都为中文,因此不需要对输入和回答分别生成两个字典,改进为输入和回答用同一个字典,

google翻译模型训练:代码 seq2seq-baseline

训练过程;

损失函数由两部分组成, 为基本的seq2seq计算得到的损失函数, 为奖惩损失函数,

初始化所有的训练变量为[-0.04,0.04],梯度剪切0.5,

由于adam算法在训练开始的时候可以加快收敛速度,但是得到收敛结果比sgb差,因此先用adam

训练,学习率为0.0002,sgd学习率为0.5,

前1.2Msteps,保存学习率不变,之后,每200k steps降低学习率,

训练ML目标函数收敛后,转移到训练RL+ML目标函数,

dropout,只在ML的时候设置dropout=0.2 or 0.3

原创粉丝点击