tensor2tensor
Source: Internet · Editor: 程序博客网 · Time: 2024/06/01 22:40
Using tensor2tensor to train the Attention Is All You Need model
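Parameter definition:
create_experiment calls train_utils.create_hparams, which calls problem_hparams.problem_hparams and then transformer. transformer calls common_hparams to obtain the basic model parameters, then adds the Transformer-specific parameters on top of them.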
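The layering described above (shared defaults first, then model-specific overrides and additions) can be sketched with plain dicts. This is a minimal illustration, not tensor2tensor code: basic_params and transformer_params are hypothetical stand-ins for common_hparams and the transformer hparams function, and the values are illustrative only.

```python
def basic_params():
    """Shared, model-agnostic defaults (hypothetical stand-in for common_hparams)."""
    return {
        "batch_size": 4096,
        "learning_rate": 0.1,
        "hidden_size": 64,
        "dropout": 0.2,
    }


def transformer_params():
    """Start from the shared defaults, then override and extend them."""
    hparams = basic_params()
    hparams["hidden_size"] = 512                   # override a shared default
    hparams.update(num_heads=8, filter_size=2048)  # add Transformer-only keys
    return hparams


hparams = transformer_params()
print(hparams)
```

Because every model starts from the same base, a new model only has to state what differs from the defaults.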
Model invocation:
trainer_utils._cond_on_index calls fn(cur_idx), where fn is trainer_utils.model_fn.nth_model. nth_model calls t2t_model._with_timing.fn_with_timing, which calls fn; that fn calls transformer.model_fn_body to obtain the loss and logits.
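The call chain above can be illustrated with a plain-Python analogue. cond_on_index, with_timing, nth_model, model_fn_body and TIMINGS are hypothetical names chosen for this sketch. In tensor2tensor the dispatch happens inside a TensorFlow graph (presumably as nested tf.cond calls), and the index appears to select which problem's branch is built, so treat this only as a sketch of the control flow.

```python
import functools
import time

TIMINGS = {}  # name -> duration (seconds) of the most recent call


def with_timing(fn, name):
    """Wrap fn so each call records its wall-clock duration in TIMINGS."""
    @functools.wraps(fn)
    def fn_with_timing(*args, **kwargs):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        TIMINGS[name] = time.perf_counter() - start
        return result
    return fn_with_timing


def cond_on_index(fn, index, max_idx, cur_idx=0):
    """Return fn(cur_idx) for the cur_idx equal to index (last branch as fallback).

    Plain-Python analogue of a recursive chain of conditionals.
    """
    if cur_idx == max_idx or index == cur_idx:
        return fn(cur_idx)
    return cond_on_index(fn, index, max_idx, cur_idx + 1)


def model_fn_body(features):
    """Hypothetical model body returning (loss, logits)."""
    logits = [2.0 * x for x in features]
    loss = sum(logits) / len(logits)
    return loss, logits


def nth_model(n):
    """Build branch n: wrap the body with timing and run it on toy features."""
    body = with_timing(model_fn_body, "model_fn_body_%d" % n)
    return body([n, n + 1])


loss, logits = cond_on_index(nth_model, index=1, max_idx=2)
print(loss, logits)
```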
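Training data generation:
The entry point is tensor2tensor/bin/t2t-datagen.py.
Each task and dataset has its own data generation function, and all of these functions are registered in the dictionary _SUPPORTED_PROBLEM_GENERATORS.
For example, to train the Attention Is All You Need model on en-fr, the data generation functions are defined as: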
```python
"wmt_enfr_tokens_32k": (
    lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
    lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)),
```
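The (train, dev) pair of lambdas in the entry above can be exercised with a toy registry. toy_token_generator and generate are hypothetical: the generator is a stand-in for wmt.enfr_wordpiece_token_generator that yields fake ids, and only the shape of the registry entry is taken from the snippet above.

```python
def toy_token_generator(train, vocab_size):
    """Hypothetical stand-in for wmt.enfr_wordpiece_token_generator.

    vocab_size is accepted to mirror the real signature but unused here.
    """
    split = "train" if train else "dev"
    for i in range(2):
        yield {"inputs": [i + 2], "targets": [i + 3], "split": split}


SUPPORTED_PROBLEM_GENERATORS = {
    "toy_enfr_tokens_32k": (
        lambda: toy_token_generator(True, 2**15),   # training data
        lambda: toy_token_generator(False, 2**15),  # dev data
    ),
}


def generate(problem):
    """Look up a problem and materialize its training and dev examples."""
    train_fn, dev_fn = SUPPORTED_PROBLEM_GENERATORS[problem]
    return list(train_fn()), list(dev_fn())


train, dev = generate("toy_enfr_tokens_32k")
print(train[0], dev[0])
```

Because each entry stores lambdas rather than generators, no data is read and no vocabulary is built until the script actually calls the pair for the chosen problem.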
enfr_wordpiece_token_generator is defined in wmt as follows:
```python
def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size):
  """Instance of token generator for the WMT en->fr task."""
  symbolizer_vocab = generator_utils.get_or_generate_vocab(
      tmp_dir, "tokens.vocab.%d" % vocab_size, vocab_size)
  datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
  tag = "train" if train else "dev"
  data_path = _compile_data(tmp_dir, datasets, "wmt_enfr_tok_%s" % tag)
  return token_generator(data_path + ".lang1", data_path + ".lang2",
                         symbolizer_vocab, 1)
```
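As a worked example of the names this function builds: 2**15 = 32768, so the "32k" problem uses the vocabulary file tokens.vocab.32768, and the source and target data end in .lang1 and .lang2 (source first, matching the argument order of token_generator, so presumably English then French). enfr_paths is a hypothetical helper, and it assumes _compile_data writes its output inside tmp_dir under the given prefix; check wmt.py for the exact location.

```python
import os


def enfr_paths(tmp_dir, train, vocab_size):
    """Hypothetical helper: the file names enfr_wordpiece_token_generator works with.

    Assumes _compile_data places its output in tmp_dir under the given prefix.
    """
    vocab_file = "tokens.vocab.%d" % vocab_size
    tag = "train" if train else "dev"
    data_path = os.path.join(tmp_dir, "wmt_enfr_tok_%s" % tag)
    return vocab_file, data_path + ".lang1", data_path + ".lang2"


print(enfr_paths("/tmp/t2t_datagen", True, 2**15))
```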
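First, generator_utils.get_or_generate_vocab generates the vocabulary, and _ENFR_TRAIN_DATASETS specifies the input (training) datasets. token_generator then reads the training data and converts it into integer ids. In the version shown here, the vocabulary is indexed like a dict (token_vocab['</S>']), text is converted with a word_num helper that the post does not show, and only sentence pairs whose longer side (including EOS) is 5 to 20 tokens long are kept: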
```python
import tensorflow as tf


def token_generator(source_path, target_path, token_vocab, eos=None):
  # Always use the id of '</S>' from the vocabulary; this overrides the eos argument.
  eos = token_vocab['</S>']
  eos_list = [] if eos is None else [eos]
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    with tf.gfile.GFile(target_path, mode="r") as target_file:
      source, target = source_file.readline(), target_file.readline()
      while source and target:
        # Convert text to ids (word_num is a helper not shown in the post) and append EOS.
        source_ints = word_num(source.strip('\n'), token_vocab) + eos_list
        target_ints = word_num(target.strip('\n'), token_vocab) + eos_list
        slen = max(len(source_ints), len(target_ints))
        if 5 <= slen <= 20:  # keep only training pairs of length [5, 20]
          yield {"inputs": source_ints, "targets": target_ints}
        # Advance both files on every iteration, whether or not the pair was kept.
        source, target = source_file.readline(), target_file.readline()
```
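To try the conversion and length filtering without TensorFlow or the WMT data, here is a self-contained sketch. It is not the post's code: io.StringIO stands in for tf.gfile.GFile, and build_word_vocab, word_num and the reserved ids for <pad>, </S> and <unk> are hypothetical, since the post does not show how its vocabulary dict or word_num are built. The length bounds are parameters that default to the [5, 20] range used above.

```python
import io


def build_word_vocab(lines):
    """Hypothetical word-level vocabulary; ids 0-2 are reserved special tokens."""
    vocab = {"<pad>": 0, "</S>": 1, "<unk>": 2}
    for line in lines:
        for word in line.split():
            vocab.setdefault(word, len(vocab))
    return vocab


def word_num(text, vocab):
    """Hypothetical version of the post's word_num: words -> ids, unknowns -> <unk>."""
    return [vocab.get(word, vocab["<unk>"]) for word in text.split()]


def token_generator(source_file, target_file, token_vocab, min_len=5, max_len=20):
    """Yield id-encoded sentence pairs whose longer side (with EOS) fits [min_len, max_len]."""
    eos_list = [token_vocab["</S>"]]
    # zip stops at the shorter file, like the `while source and target` loop.
    for source, target in zip(source_file, target_file):
        source_ints = word_num(source.strip("\n"), token_vocab) + eos_list
        target_ints = word_num(target.strip("\n"), token_vocab) + eos_list
        if min_len <= max(len(source_ints), len(target_ints)) <= max_len:
            yield {"inputs": source_ints, "targets": target_ints}


src_text = "the cat sat on the mat\nhi\n"
tgt_text = "le chat est assis sur le tapis\nsalut\n"
vocab = build_word_vocab(src_text.splitlines() + tgt_text.splitlines())
examples = list(token_generator(io.StringIO(src_text), io.StringIO(tgt_text), vocab))
print(examples)  # the short "hi" / "salut" pair is filtered out
```

Iterating over the two file objects with zip replaces the manual readline loop while keeping the same stopping rule.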